From b32ed1d7be6b695851e0468386ec715c9dc8a0e9 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 4 Aug 2023 21:27:35 +0800
Subject: [PATCH] support interleave probs

Former-commit-id: 69744c17e8180e0ad549b57d575454724b820d01
---
 src/llmtuner/dsets/loader.py      |  2 +-
 src/llmtuner/extras/constants.py  | 15 +++++++++------
 src/llmtuner/hparams/data_args.py |  7 +++++++
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py
index a51b9024..90a4212f 100644
--- a/src/llmtuner/dsets/loader.py
+++ b/src/llmtuner/dsets/loader.py
@@ -111,6 +111,6 @@ def get_dataset(
         if not data_args.streaming:
             logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
         stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
-        return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
+        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
     else:
         raise ValueError("Unknown mixing strategy.")

diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index d2118408..6f6dbdd7 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -25,15 +25,17 @@ SUPPORTED_MODELS = {
     "BLOOMZ-560M": "bigscience/bloomz-560m",
     "BLOOMZ-3B": "bigscience/bloomz-3b",
     "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
-    "Falcon-7B-Base": "tiiuae/falcon-7b",
+    "Falcon-7B": "tiiuae/falcon-7b",
     "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
-    "Falcon-40B-Base": "tiiuae/falcon-40b",
+    "Falcon-40B": "tiiuae/falcon-40b",
     "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
     "Baichuan-7B": "baichuan-inc/Baichuan-7B",
-    "Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
+    "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
     "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
-    "InternLM-7B-Base": "internlm/internlm-7b",
-    "InternLM-7B-Chat": "internlm/internlm-chat-7b"
+    "InternLM-7B": "internlm/internlm-7b",
+    "InternLM-7B-Chat": "internlm/internlm-chat-7b",
+    "Qwen-7B": "Qwen/Qwen-7B",
+    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat"
 }
 
 DEFAULT_MODULE = {
@@ -43,5 +45,6 @@ DEFAULT_MODULE = {
     "BLOOMZ": "query_key_value",
     "Falcon": "query_key_value",
     "Baichuan": "W_pack",
-    "InternLM": "q_proj,v_proj"
+    "InternLM": "q_proj,v_proj",
+    "Qwen": "c_attn"
 }

diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index ce88d4d9..60945b60 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -54,6 +54,10 @@ class DataArguments:
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing."}
     )
+    interleave_probs: Optional[str] = field(
+        default=None,
+        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
+    )
     overwrite_cache: Optional[bool] = field(
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -103,6 +107,9 @@ class DataArguments:
         else:
             prefix_list = [None] * len(dataset_names)
 
+        if self.interleave_probs is not None:
+            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
+
         self.dataset_list: List[DatasetAttr] = []
         for i, name in enumerate(dataset_names):
             if name not in dataset_info:
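
For reference, below is a minimal standalone sketch of the behavior this patch wires up, using the `datasets` library directly. The toy datasets and the `probs` variable name are illustrative assumptions, not part of the patch; only the parsing expression and the `interleave_datasets` call mirror the changes in `DataArguments.__post_init__` and `get_dataset`.

```python
# Minimal sketch (assumes the Hugging Face `datasets` library is installed).
# The toy data here is hypothetical; the parsing and interleaving mirror the patch.
from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2", "b3"]})

# Equivalent of passing --interleave_probs "0.8,0.2" on the command line;
# parsed into a list of floats exactly as in DataArguments.__post_init__.
interleave_probs = "0.8,0.2"
probs = [float(p.strip()) for p in interleave_probs.split(",")]

# As in get_dataset(): the probabilities are the second positional argument
# of interleave_datasets. A `mix_strategy` ending in "under" maps to
# "first_exhausted"; otherwise "all_exhausted" is used.
mixed = interleave_datasets([ds_a, ds_b], probs, stopping_strategy="first_exhausted")
print(mixed["text"])
```

Passing `probs=None` (the default when `interleave_probs` is unset) preserves the old behavior, since `interleave_datasets` then samples uniformly across datasets.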