From 95d3c2620b7b8180fdf9eb6c769af431f99e2cf5 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 19 Dec 2024 08:53:21 +0000 Subject: [PATCH] support disable shuffling Former-commit-id: c7cedc7569973a2879c689637b2923e8b26f1a81 --- .gitignore | 3 + src/llamafactory/hparams/finetuning_args.py | 4 + src/llamafactory/train/dpo/trainer.py | 13 +++- src/llamafactory/train/kto/trainer.py | 9 ++- src/llamafactory/train/pt/trainer.py | 17 ++++- src/llamafactory/train/rm/trainer.py | 7 ++ src/llamafactory/train/sft/trainer.py | 13 +++- tests/e2e/test_train.py | 4 +- tests/train/test_sft_trainer.py | 81 +++++++++++++++++++++ 9 files changed, 139 insertions(+), 12 deletions(-) create mode 100644 tests/train/test_sft_trainer.py diff --git a/.gitignore b/.gitignore index 88c36ca2..5e6121d3 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,6 @@ saves/ output/ wandb/ generated_predictions.jsonl + +# unittest +dummy_dir/ diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 8cfea728..6d350a73 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."}, ) + disable_shuffling: bool = field( + default=False, + metadata={"help": "Whether or not to disable the shuffling of the training set."}, + ) plot_loss: bool = field( default=False, metadata={"help": "Whether or not to save the training loss curves."}, diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 7e76dee2..330de386 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -19,7 +19,7 @@ import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType -from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -119,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) + @override + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + if self.finetuning_args.disable_shuffling: + return torch.utils.data.SequentialSampler(self.train_dataset) + + return super()._get_train_sampler() + @override def get_batch_samples(self, epoch_iterator, num_batches): r""" @@ -266,7 +273,9 @@ class CustomDPOTrainer(DPOTrainer): return losses.mean(), metrics @override - def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + def compute_loss( + self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" Fixes the loss value for transformers 4.46.0. 
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index e22b16a4..3d007ae7 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -19,7 +19,7 @@ import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType -from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union import torch from transformers import Trainer @@ -119,6 +119,9 @@ class CustomKTOTrainer(KTOTrainer): r""" Replaces the sequential sampler of KTO Trainer created by trl with the random sampler. """ + if self.finetuning_args.disable_shuffling: + return torch.utils.data.SequentialSampler(self.train_dataset) + return Trainer._get_train_sampler(self) @override @@ -245,7 +248,9 @@ class CustomKTOTrainer(KTOTrainer): return losses, metrics @override - def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + def compute_loss( + self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" Fixes the loss value for transformers 4.46.0. https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 37dcadfd..2e77ba43 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -13,8 +13,9 @@ # limitations under the License. from types import MethodType -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +import torch from transformers import Trainer from typing_extensions import override @@ -24,8 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler if TYPE_CHECKING: - import torch - from transformers import ProcessorMixin + from transformers import PreTrainedModel, ProcessorMixin from ...hparams import FinetuningArguments @@ -70,7 +70,16 @@ class CustomTrainer(Trainer): return super().create_scheduler(num_training_steps, optimizer) @override - def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + if self.finetuning_args.disable_shuffling: + return torch.utils.data.SequentialSampler(self.train_dataset) + + return super()._get_train_sampler() + + @override + def compute_loss( + self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" Fixes the loss value for transformers 4.46.0. 
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index bccfdef5..4b740837 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -81,6 +81,13 @@ class PairwiseTrainer(Trainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) + @override + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + if self.finetuning_args.disable_shuffling: + return torch.utils.data.SequentialSampler(self.train_dataset) + + return super()._get_train_sampler() + @override def compute_loss( self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 0f118bbb..3136d7fd 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler if TYPE_CHECKING: from torch.utils.data import Dataset - from transformers import PreTrainedTokenizer, ProcessorMixin + from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers.trainer import PredictionOutput from ...hparams import FinetuningArguments @@ -85,7 +85,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): return super().create_scheduler(num_training_steps, optimizer) @override - def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + if self.finetuning_args.disable_shuffling: + return torch.utils.data.SequentialSampler(self.train_dataset) + + return super()._get_train_sampler() + + @override + def compute_loss( + self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" Fixes the loss value for transformers 4.46.0. https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605 diff --git a/tests/e2e/test_train.py b/tests/e2e/test_train.py index 71cda495..d1eae617 100644 --- a/tests/e2e/test_train.py +++ b/tests/e2e/test_train.py @@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "") ], ) def test_run_exp(stage: str, dataset: str): - output_dir = os.path.join("output", f"train_{stage}") + output_dir = os.path.join("output", f"dummy_dir/train_{stage}") run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS}) assert os.path.exists(output_dir) def test_export(): - export_dir = os.path.join("output", "llama3_export") + export_dir = os.path.join("output", "dummy_dir/llama3_export") export_model({"export_dir": export_dir, **INFER_ARGS}) assert os.path.exists(export_dir) diff --git a/tests/train/test_sft_trainer.py b/tests/train/test_sft_trainer.py new file mode 100644 index 00000000..e4391c10 --- /dev/null +++ b/tests/train/test_sft_trainer.py @@ -0,0 +1,81 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List + +import pytest +from transformers import DataCollatorWithPadding + +from llamafactory.data import get_dataset, get_template_and_fix_tokenizer +from llamafactory.hparams import get_train_args +from llamafactory.model import load_model, load_tokenizer +from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer + + +DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data") + +TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "lora", + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": False, + "overwrite_output_dir": True, + "per_device_train_batch_size": 1, + "max_steps": 1, +} + + +@dataclass +class DataCollatorWithVerbose(DataCollatorWithPadding): + verbose_list: List[Dict[str, Any]] = field(default_factory=list) + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + self.verbose_list.extend(features) + batch = super().__call__(features) + return {k: v[:, :1] for k, v in batch.items()} # truncate input length + + +@pytest.mark.parametrize("disable_shuffling", [False, True]) +def test_shuffle(disable_shuffling: bool): + model_args, data_args, training_args, finetuning_args, _ = get_train_args( + {"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS} + ) + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + template = get_template_and_fix_tokenizer(tokenizer, data_args) + dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + data_collator = DataCollatorWithVerbose(tokenizer=tokenizer) + trainer = CustomSeq2SeqTrainer( + model=model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + **dataset_module, + **tokenizer_module, + ) + trainer.train() + if disable_shuffling: + assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"] + else: + assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]
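
Editor's note: the same three-line override is repeated verbatim in the DPO, KTO, PT, RM, and SFT trainers above. For readers skimming the patch, the snippet below is a minimal, self-contained sketch of that pattern outside the diff. The class name NoShuffleTrainer and the explicit finetuning_args constructor argument are assumptions for illustration only; in the patch each trainer already stores finetuning_args in its own __init__.

    # Sketch of the _get_train_sampler override added by this patch (illustrative only).
    from typing import Optional

    import torch
    from transformers import Trainer
    from typing_extensions import override


    class NoShuffleTrainer(Trainer):  # hypothetical name, stands in for the patched trainers
        def __init__(self, finetuning_args, **kwargs):
            super().__init__(**kwargs)
            self.finetuning_args = finetuning_args

        @override
        def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
            # With shuffling disabled, iterate the training set in its stored order.
            if self.finetuning_args.disable_shuffling:
                return torch.utils.data.SequentialSampler(self.train_dataset)

            # Otherwise keep the default sampler (random in single-process training).
            return super()._get_train_sampler()

On the user side, the new flag is an ordinary FinetuningArguments field, so passing "disable_shuffling": True in the training arguments (as the new test_shuffle test does via get_train_args) is enough to switch every patched trainer to the sequential sampler; the default of False leaves the existing shuffled behavior unchanged.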