Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

support disable shuffling

Former-commit-id: 9d8c35fd6b838ede0bd6827c6c6121f2cba2b11b
Commit: 01eeae50b5 (parent: eca06531c3)

.gitignore (vendored): 3 lines changed
@@ -172,3 +172,6 @@ saves/
 output/
 wandb/
 generated_predictions.jsonl
+
+# unittest
+dummy_dir/

@@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
     )
+    disable_shuffling: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable the shuffling of the training set."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
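
For orientation, here is a minimal sketch of how a flag declared this way gets populated. ShuffleArgs is an illustrative stand-in for FinetuningArguments, not the project's class; the only assumption is that transformers' HfArgumentParser is available.

    from dataclasses import dataclass, field

    from transformers import HfArgumentParser


    @dataclass
    class ShuffleArgs:
        # Illustrative copy of the field added above, not the real FinetuningArguments.
        disable_shuffling: bool = field(
            default=False,
            metadata={"help": "Whether or not to disable the shuffling of the training set."},
        )


    # parse_dict mimics how a YAML/JSON training config fills the dataclass.
    (args,) = HfArgumentParser(ShuffleArgs).parse_dict({"disable_shuffling": True})
    print(args.disable_shuffling)  # True

In LLaMA-Factory itself the value arrives through the normal training-argument pipeline (CLI or YAML config); the snippet only shows the shape of the field and how a disable_shuffling entry lands in it.
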
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -119,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def get_batch_samples(self, epoch_iterator, num_batches):
         r"""
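
The same sampler override is repeated in the KTO, pre-training, reward-model, and SFT trainers further down. As a standalone reference, here is a hedged sketch of the pattern on top of a plain transformers Trainer, assuming a transformers version whose _get_train_sampler takes no arguments (as in these hunks); NoShuffleTrainer and its constructor flag are illustrative, not part of LLaMA-Factory.

    from typing import Optional

    import torch
    from transformers import Trainer


    class NoShuffleTrainer(Trainer):
        """Illustrative subclass; LLaMA-Factory reads the flag from finetuning_args instead."""

        def __init__(self, *args, disable_shuffling: bool = False, **kwargs):
            super().__init__(*args, **kwargs)
            self.disable_shuffling = disable_shuffling

        def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
            # Keep the original dataset order instead of a random permutation.
            if self.disable_shuffling:
                return torch.utils.data.SequentialSampler(self.train_dataset)

            return super()._get_train_sampler()

Returning a SequentialSampler yields indices 0..N-1 in order, which is what makes per-sample comparisons (like the new unit test at the bottom of this diff) deterministic.
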
@@ -266,7 +273,9 @@ class CustomDPOTrainer(DPOTrainer):
         return losses.mean(), metrics
 
     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 from transformers import Trainer
@@ -119,6 +119,9 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
         return Trainer._get_train_sampler(self)
 
     @override
@@ -245,7 +248,9 @@ class CustomKTOTrainer(KTOTrainer):
         return losses, metrics
 
     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 from types import MethodType
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
+import torch
 from transformers import Trainer
 from typing_extensions import override
 
@@ -24,8 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
-    import torch
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, ProcessorMixin
 
     from ...hparams import FinetuningArguments
 
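
A note on the import change just above: torch moves out of the if TYPE_CHECKING: block because torch.utils.data.SequentialSampler is now constructed at runtime, while PreTrainedModel can stay a type-checking-only import since it appears only inside string annotations. A minimal sketch of that split (pick_sampler is a hypothetical helper, not a function in the repo):

    from typing import TYPE_CHECKING, Optional

    import torch  # needed at runtime: SequentialSampler is constructed below

    if TYPE_CHECKING:
        # Only used in annotations, so no runtime import is required here.
        from transformers import PreTrainedModel


    def pick_sampler(
        model: "PreTrainedModel", dataset, disable_shuffling: bool
    ) -> Optional["torch.utils.data.Sampler"]:
        # model is unused; it is only present to show a TYPE_CHECKING-only annotation.
        if disable_shuffling:
            return torch.utils.data.SequentialSampler(dataset)
        return None  # caller falls back to its default (shuffled) sampler
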
@@ -70,7 +70,16 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
@@ -81,6 +81,13 @@ class PairwiseTrainer(Trainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput
 
     from ...hparams import FinetuningArguments
@@ -85,7 +85,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)
 
     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Fixes the loss value for transformers 4.46.0.
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
@@ -60,12 +60,12 @@ OS_NAME = os.getenv("OS_NAME", "")
     ],
 )
 def test_run_exp(stage: str, dataset: str):
-    output_dir = os.path.join("output", f"train_{stage}")
+    output_dir = os.path.join("output", f"dummy_dir/train_{stage}")
     run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
     assert os.path.exists(output_dir)
 
 
 def test_export():
-    export_dir = os.path.join("output", "llama3_export")
+    export_dir = os.path.join("output", "dummy_dir/llama3_export")
     export_model({"export_dir": export_dir, **INFER_ARGS})
     assert os.path.exists(export_dir)

tests/train/test_sft_trainer.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, List
+
+import pytest
+from transformers import DataCollatorWithPadding
+
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
+from llamafactory.hparams import get_train_args
+from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer
+
+
+DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
+
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TRAIN_ARGS = {
+    "model_name_or_path": TINY_LLAMA,
+    "stage": "sft",
+    "do_train": True,
+    "finetuning_type": "lora",
+    "dataset": "llamafactory/tiny-supervised-dataset",
+    "dataset_dir": "ONLINE",
+    "template": "llama3",
+    "cutoff_len": 1024,
+    "overwrite_cache": False,
+    "overwrite_output_dir": True,
+    "per_device_train_batch_size": 1,
+    "max_steps": 1,
+}
+
+
+@dataclass
+class DataCollatorWithVerbose(DataCollatorWithPadding):
+    verbose_list: List[Dict[str, Any]] = field(default_factory=list)
+
+    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+        self.verbose_list.extend(features)
+        batch = super().__call__(features)
+        return {k: v[:, :1] for k, v in batch.items()}  # truncate input length
+
+
+@pytest.mark.parametrize("disable_shuffling", [False, True])
+def test_shuffle(disable_shuffling: bool):
+    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
+        {"output_dir": f"dummy_dir/{disable_shuffling}", "disable_shuffling": disable_shuffling, **TRAIN_ARGS}
+    )
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+    data_collator = DataCollatorWithVerbose(tokenizer=tokenizer)
+    trainer = CustomSeq2SeqTrainer(
+        model=model,
+        args=training_args,
+        finetuning_args=finetuning_args,
+        data_collator=data_collator,
+        **dataset_module,
+        **tokenizer_module,
+    )
+    trainer.train()
+    if disable_shuffling:
+        assert data_collator.verbose_list[0]["input_ids"] == dataset_module["train_dataset"][0]["input_ids"]
+    else:
+        assert data_collator.verbose_list[0]["input_ids"] != dataset_module["train_dataset"][0]["input_ids"]
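
The test above works by spying on the data collator: it records every feature the collator receives and compares the first one against the first raw training sample, which must match exactly when shuffling is disabled. Here is a self-contained sketch of that assertion idea in plain PyTorch, with a spy collate function standing in for DataCollatorWithVerbose and a toy TensorDataset standing in for the tokenized dataset (both illustrative):

    import torch
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(32).reshape(8, 4))  # 8 samples of 4 "tokens" each
    seen = []  # the spy: records every sample the collate function receives


    def collate(features):
        seen.extend(features)
        return torch.stack([f[0] for f in features])


    for disable_shuffling in (True, False):
        seen.clear()
        sampler = SequentialSampler(dataset) if disable_shuffling else RandomSampler(dataset)
        for _ in DataLoader(dataset, sampler=sampler, batch_size=2, collate_fn=collate):
            pass
        first_matches = torch.equal(seen[0][0], dataset[0][0])
        print(disable_shuffling, first_matches)  # True -> always matches; False -> usually not

The real test exercises the full CustomSeq2SeqTrainer path instead and can be run on its own with pytest tests/train/test_sft_trainer.py once the tiny model and dataset are reachable.
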