support interleave probs

Former-commit-id: 69744c17e8180e0ad549b57d575454724b820d01
This commit is contained in:
hiyouga 2023-08-04 21:27:35 +08:00
parent 44823ec2c7
commit b32ed1d7be
3 changed files with 17 additions and 7 deletions

View File

@ -111,6 +111,6 @@ def get_dataset(
if not data_args.streaming: if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy) return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
else: else:
raise ValueError("Unknown mixing strategy.") raise ValueError("Unknown mixing strategy.")

View File

@ -25,15 +25,17 @@ SUPPORTED_MODELS = {
"BLOOMZ-560M": "bigscience/bloomz-560m", "BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b", "BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt", "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b", "Falcon-7B": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct", "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b", "Falcon-40B": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct", "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B", "Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base", "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b", "InternLM-7B": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b" "InternLM-7B-Chat": "internlm/internlm-chat-7b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat"
} }
DEFAULT_MODULE = { DEFAULT_MODULE = {
@ -43,5 +45,6 @@ DEFAULT_MODULE = {
"BLOOMZ": "query_key_value", "BLOOMZ": "query_key_value",
"Falcon": "query_key_value", "Falcon": "query_key_value",
"Baichuan": "W_pack", "Baichuan": "W_pack",
"InternLM": "q_proj,v_proj" "InternLM": "q_proj,v_proj",
"Qwen": "c_attn"
} }

View File

@ -54,6 +54,10 @@ class DataArguments:
default="concat", default="concat",
metadata={"help": "Strategy to use in dataset mixing."} metadata={"help": "Strategy to use in dataset mixing."}
) )
interleave_probs: Optional[str] = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
)
overwrite_cache: Optional[bool] = field( overwrite_cache: Optional[bool] = field(
default=False, default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."} metadata={"help": "Overwrite the cached training and evaluation sets."}
@ -103,6 +107,9 @@ class DataArguments:
else: else:
prefix_list = [None] * len(dataset_names) prefix_list = [None] * len(dataset_names)
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
self.dataset_list: List[DatasetAttr] = [] self.dataset_list: List[DatasetAttr] = []
for i, name in enumerate(dataset_names): for i, name in enumerate(dataset_names):
if name not in dataset_info: if name not in dataset_info: