Merge pull request #6388 from hiyouga/hiyouga/shuffle_control

[trainer] support disable shuffling

Former-commit-id: ffbb4dbdb09ba799af1800c78b2e9d669bccd24b
Authored by hoshi-hiyouga on 2024-12-19 17:00:12 +08:00, committed by GitHub
commit af9ef037dd
9 changed files with 139 additions and 12 deletions
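Every affected trainer gets the same change: when the new disable_shuffling flag is set, _get_train_sampler returns a SequentialSampler instead of the default randomized sampler. Below is a condensed sketch of that pattern; the class name is a stand-in, and the finetuning_args attribute is assumed to be attached to the trainer as in the files patched in this diff.

from typing import Optional

import torch
from transformers import Trainer
from typing_extensions import override


class ShuffleControlledTrainer(Trainer):  # stand-in name, not a class from this PR
    @override
    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            # Visit the training set in its original order.
            return torch.utils.data.SequentialSampler(self.train_dataset)

        # Otherwise keep the parent Trainer's (randomized) sampler.
        return super()._get_train_sampler()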

.gitignore

@@ -172,3 +172,6 @@ saves/
 output/
 wandb/
 generated_predictions.jsonl
+
+# unittest
+dummy_dir/

@@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
     )
+    disable_shuffling: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable the shuffling of the training set."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},

@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -119,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def get_batch_samples(self, epoch_iterator, num_batches):
         r"""
@@ -266,7 +273,9 @@ class CustomDPOTrainer(DPOTrainer):
         return losses.mean(), metrics

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
 from transformers import Trainer
@@ -119,6 +119,9 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
         return Trainer._get_train_sampler(self)

     @override
@@ -245,7 +248,9 @@ class CustomKTOTrainer(KTOTrainer):
         return losses, metrics

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

@@ -13,8 +13,9 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

+import torch
 from transformers import Trainer
 from typing_extensions import override
@@ -24,8 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
-    import torch
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments
@@ -70,7 +70,16 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

@@ -81,6 +81,13 @@ class PairwiseTrainer(Trainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs

@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments
@@ -85,7 +85,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

@@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "")
     ],
 )
 def test_run_exp(stage: str, dataset: str):
-    output_dir = os.path.join("output", f"train_{stage}")
+    output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
     run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
     assert os.path.exists(output_dir)


 def test_export():
-    export_dir = os.path.join("output", "llama3_export")
+    export_dir = os.path.join("output", "dummy_dir/llama3_export")
     export_model({"export_dir": export_dir, **INFER_ARGS})
     assert os.path.exists(export_dir)

@@ -0,0 +1,81 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Any, Dict, List

import pytest
from transformers import DataCollatorWithPadding

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer


DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_cache": False,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
}


@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
    verbose_list: List[Dict[str, Any]] = field(default_factory=list)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        self.verbose_list.extend(features)
        batch = super().__call__(features)
        return {k: v[:, :1] for k, v in batch.items()}  # truncate input length


@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        {"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = DataCollatorWithVerbose(tokenizer=tokenizer)
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        **dataset_module,
        **tokenizer_module,
    )
    trainer.train()
    if disable_shuffling:
        assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"]
    else:
        assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]
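The final assertions rely on a basic sampler property: with a SequentialSampler the first collated batch is the first dataset sample, while the default randomized sampler almost always starts somewhere else. A self-contained PyTorch illustration of that property (purely illustrative, not part of the PR):

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

data = TensorDataset(torch.arange(8))

# Sequential: batches follow dataset order, so the first batch starts with sample 0.
ordered = DataLoader(data, batch_size=2, sampler=SequentialSampler(data))
print([batch[0].tolist() for batch in ordered])  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# Random (the default when shuffling is enabled): a seeded permutation of the indices.
generator = torch.Generator().manual_seed(42)
shuffled = DataLoader(data, batch_size=2, sampler=RandomSampler(data, generator=generator))
print([batch[0].tolist() for batch in shuffled])  # the same eight values in shuffled order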