From 47a1f73d0f4d95147a0ecde2813e09f520d41195 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 19 Oct 2023 16:17:41 +0800
Subject: [PATCH] fix #1218

Former-commit-id: b301f35bd4a3bf368159c8f5fb4e2736f922115b
---
 src/llmtuner/dsets/loader.py      | 8 ++++++--
 src/llmtuner/hparams/data_args.py | 5 +++--
 src/llmtuner/tuner/core/parser.py | 4 ++--
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py
index 3b42f17d..826b548c 100644
--- a/src/llmtuner/dsets/loader.py
+++ b/src/llmtuner/dsets/loader.py
@@ -88,7 +88,11 @@ def get_dataset(
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
             logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
-        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
-        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
+        return interleave_datasets(
+            datasets=all_datasets,
+            probabilities=data_args.interleave_probs,
+            seed=data_args.seed,
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+        )
     else:
         raise ValueError("Unknown mixing strategy.")
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index 9d432c56..839dec8f 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -60,7 +60,7 @@ class DataArguments:
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing."}
+        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
     )
     interleave_probs: Optional[str] = field(
         default=None,
@@ -106,7 +106,8 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
-    def init_for_training(self): # support mixing multiple datasets
+    def init_for_training(self, seed: int): # support mixing multiple datasets
+        self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
         try:
             with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 50e96bb0..f4da7712 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -88,8 +88,8 @@ def get_train_args(
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
 
-    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
-    data_args.init_for_training()
+    # Check arguments
+    data_args.init_for_training(training_args.seed)
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")