Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 04:02:49 +08:00)
parent: 0622f8a3a3
commit: 4930118761
Binary image file not shown (size: 140 KiB before, 146 KiB after).
```diff
@@ -88,7 +88,11 @@ def get_dataset(
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
             logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
-        stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
-        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
+        return interleave_datasets(
+            datasets=all_datasets,
+            probabilities=data_args.interleave_probs,
+            seed=data_args.seed,
+            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
+        )
     else:
         raise ValueError("Unknown mixing strategy.")
```
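For context, this hunk switches the call to keyword arguments and threads a seed into Hugging Face `datasets.interleave_datasets`. A minimal standalone sketch (toy datasets, not LLaMA-Factory code) of what that call does:

```python
# Minimal sketch using the Hugging Face `datasets` API referenced above.
# With a fixed `seed`, the interleaving order is reproducible across runs.
from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"text": [f"a{i}" for i in range(4)]})
ds_b = Dataset.from_dict({"text": [f"b{i}" for i in range(8)]})

mixed = interleave_datasets(
    datasets=[ds_a, ds_b],
    probabilities=[0.5, 0.5],             # analogous to `interleave_probs`
    seed=42,                              # the argument this commit adds
    stopping_strategy="first_exhausted",  # "interleave_under" behavior
)
print(mixed["text"])  # identical output on every run with the same seed
```

`first_exhausted` stops as soon as the smallest dataset runs out (undersampling), while `all_exhausted` keeps sampling until every dataset has been seen in full (oversampling).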
```diff
@@ -60,7 +60,7 @@ class DataArguments:
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
         default="concat",
-        metadata={"help": "Strategy to use in dataset mixing."}
+        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
     )
     interleave_probs: Optional[str] = field(
         default=None,
```
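Note that `interleave_probs` is declared as `Optional[str]`, so a comma-separated value must be converted to floats before it can serve as the `probabilities` argument. A hypothetical parser (`parse_interleave_probs` is an illustration, not a project function):

```python
from typing import List, Optional

def parse_interleave_probs(raw: Optional[str]) -> Optional[List[float]]:
    """Hypothetical helper: turn a string like "0.3,0.7" into [0.3, 0.7]."""
    if raw is None:
        return None
    probs = [float(p.strip()) for p in raw.split(",")]
    if abs(sum(probs) - 1.0) > 1e-6:
        raise ValueError("interleave_probs should sum to 1.")
    return probs

print(parse_interleave_probs("0.3, 0.7"))  # [0.3, 0.7]
```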
```diff
@@ -106,7 +106,8 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
-    def init_for_training(self): # support mixing multiple datasets
+    def init_for_training(self, seed: int): # support mixing multiple datasets
+        self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
         try:
             with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
```
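The pattern in this hunk: the data-arguments object receives the training seed once and stores it, so later consumers (the dataset loader above) can read `data_args.seed`. A simplified sketch with assumed field names:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniDataArguments:
    dataset: Optional[str] = None

    def init_for_training(self, seed: int) -> None:
        # Stored so the loader can forward it to interleave_datasets.
        self.seed = seed
        self.dataset_list: List[str] = (
            [d.strip() for d in self.dataset.split(",")] if self.dataset else []
        )

args = MiniDataArguments(dataset="alpaca_en, alpaca_zh")
args.init_for_training(seed=42)
print(args.seed, args.dataset_list)  # 42 ['alpaca_en', 'alpaca_zh']
```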
```diff
@@ -88,8 +88,8 @@ def get_train_args(
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
 
-    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
-    data_args.init_for_training()
+    # Check arguments
+    data_args.init_for_training(training_args.seed)
 
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
```
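End to end, the seed flows from the parsed training arguments into the data arguments. A hedged, self-contained sketch using `transformers.HfArgumentParser` (simplified dataclasses; the project's real argument classes differ):

```python
from dataclasses import dataclass
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class MiniDataArguments:
    dataset: str = "alpaca_en"

    def init_for_training(self, seed: int) -> None:
        # Same seed as training -> reproducible dataset mixing.
        self.seed = seed

parser = HfArgumentParser((MiniDataArguments, TrainingArguments))
data_args, training_args = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--seed", "1234"]
)
data_args.init_for_training(training_args.seed)
assert data_args.seed == 1234
```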