Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-11 23:52:50 +08:00)

commit af9ef037dd

    Merge pull request #6388 from hiyouga/hiyouga/shuffle_control

    [trainer] support disable shuffling

    Former-commit-id: ffbb4dbdb09ba799af1800c78b2e9d669bccd24b
.gitignore (vendored, +3)
@@ -172,3 +172,6 @@ saves/
 output/
 wandb/
 generated_predictions.jsonl
+
+# unittest
+dummy_dir/
src/llamafactory/hparams/finetuning_args.py
@@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
     )
+    disable_shuffling: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable the shuffling of the training set."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
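The new flag is parsed like any other FinetuningArguments field. As a minimal sketch (not part of this commit; every value besides disable_shuffling is illustrative and mirrors the TRAIN_ARGS dict of the new test further below), it can be toggled through the same dict interface that get_train_args already accepts:

from llamafactory.hparams import get_train_args

# Illustrative argument dict; only "disable_shuffling" is the field added by this commit.
args = {
    "model_name_or_path": "llamafactory/tiny-random-Llama-3",
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "output_dir": "dummy_dir/no_shuffle",  # hypothetical output directory
    "disable_shuffling": True,
}

# The flag lands on finetuning_args and is later read by each trainer's _get_train_sampler override.
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
assert finetuning_args.disable_shuffling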
src/llamafactory/train/dpo/trainer.py
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -119,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def get_batch_samples(self, epoch_iterator, num_batches):
         r"""
@@ -266,7 +273,9 @@ class CustomDPOTrainer(DPOTrainer):
         return losses.mean(), metrics

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
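For context on the override above: when disable_shuffling is set, the trainer hands back a SequentialSampler instead of deferring to the default (shuffling) sampler. A self-contained PyTorch sketch, independent of LLaMA-Factory, of what that swap means for batch order:

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

# Toy dataset with items 0..7 so the ordering is easy to see.
dataset = TensorDataset(torch.arange(8))

# SequentialSampler yields indices in dataset order; RandomSampler permutes them,
# which is roughly what the default HF Trainer sampler does during training.
sequential_loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))
shuffled_loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset, generator=torch.Generator().manual_seed(0)))

print([batch[0].tolist() for batch in sequential_loader])  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print([batch[0].tolist() for batch in shuffled_loader])    # a permutation of 0..7 in pairs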
src/llamafactory/train/kto/trainer.py
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
 from transformers import Trainer
@@ -119,6 +119,9 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
         return Trainer._get_train_sampler(self)

     @override
@@ -245,7 +248,9 @@ class CustomKTOTrainer(KTOTrainer):
         return losses, metrics

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
src/llamafactory/train/pt/trainer.py
@@ -13,8 +13,9 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

+import torch
 from transformers import Trainer
 from typing_extensions import override

@@ -24,8 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


 if TYPE_CHECKING:
-    import torch
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments

@@ -70,7 +70,16 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
src/llamafactory/train/rm/trainer.py
@@ -81,6 +81,13 @@ class PairwiseTrainer(Trainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
src/llamafactory/train/sft/trainer.py
@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments
@@ -85,7 +85,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
tests (run_exp / export_model)
@@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "")
     ],
 )
 def test_run_exp(stage: str, dataset: str):
-    output_dir = os.path.join("output", f"train_{stage}")
+    output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
     run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
     assert os.path.exists(output_dir)


 def test_export():
-    export_dir = os.path.join("output", "llama3_export")
+    export_dir = os.path.join("output", "dummy_dir/llama3_export")
     export_model({"export_dir": export_dir, **INFER_ARGS})
     assert os.path.exists(export_dir)
tests/train/test_sft_trainer.py (new file, +81)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Any, Dict, List

import pytest
from transformers import DataCollatorWithPadding

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer


DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_cache": False,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
}


@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
    verbose_list: List[Dict[str, Any]] = field(default_factory=list)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        self.verbose_list.extend(features)
        batch = super().__call__(features)
        return {k: v[:, :1] for k, v in batch.items()}  # truncate input length


@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        {"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = DataCollatorWithVerbose(tokenizer=tokenizer)
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        **dataset_module,
        **tokenizer_module,
    )
    trainer.train()
    if disable_shuffling:
        assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"]
    else:
        assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]
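One possible way to run just the new test locally (assumes network access to pull the tiny model and dataset from the Hugging Face Hub, as in the other tests):

# Equivalent to `pytest -k test_shuffle tests/train/test_sft_trainer.py` on the command line;
# the test is parametrized, so both the shuffled and the non-shuffled paths are exercised.
import pytest

raise SystemExit(pytest.main(["-k", "test_shuffle", "tests/train/test_sft_trainer.py"]))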