support disable shuffling

Former-commit-id: c7cedc7569973a2879c689637b2923e8b26f1a81
This commit is contained in:
hiyouga 2024-12-19 08:53:21 +00:00
parent d6ce1045f7
commit 95d3c2620b
9 changed files with 139 additions and 12 deletions

3
.gitignore vendored
View File

@ -172,3 +172,6 @@ saves/
output/
wandb/
generated_predictions.jsonl
# unittest
dummy_dir/

View File

@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
)
disable_shuffling: bool = field(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},

View File

@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@ -119,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
@ -266,7 +273,9 @@ class CustomDPOTrainer(DPOTrainer):
return losses.mean(), metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

View File

@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
from transformers import Trainer
@ -119,6 +119,9 @@ class CustomKTOTrainer(KTOTrainer):
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return Trainer._get_train_sampler(self)
@override
@ -245,7 +248,9 @@ class CustomKTOTrainer(KTOTrainer):
return losses, metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

View File

@ -13,8 +13,9 @@
# limitations under the License.
from types import MethodType
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
from typing_extensions import override
@ -24,8 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
import torch
from transformers import ProcessorMixin
from transformers import PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments
@ -70,7 +70,16 @@ class CustomTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

View File

@ -81,6 +81,13 @@ class PairwiseTrainer(Trainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs

View File

@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
@ -85,7 +85,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

View File

@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "")
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = os.path.join("output", f"train_{stage}")
output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)
def test_export():
export_dir = os.path.join("output", "llama3_export")
export_dir = os.path.join("output", "dummy_dir/llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})
assert os.path.exists(export_dir)

View File

@ -0,0 +1,81 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List
import pytest
from transformers import DataCollatorWithPadding
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
"dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
"overwrite_cache": False,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
}
@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
verbose_list: List[Dict[str, Any]] = field(default_factory=list)
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
self.verbose_list.extend(features)
batch = super().__call__(features)
return {k: v[:, :1] for k, v in batch.items()} # truncate input length
@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
{"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorWithVerbose(tokenizer=tokenizer)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
**dataset_module,
**tokenizer_module,
)
trainer.train()
if disable_shuffling:
assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"]
else:
assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]