From b77c745b1ac1e8c574e1ca09ff403a728c42d719 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 2 Nov 2023 23:10:04 +0800 Subject: [PATCH] support sharegpt format, add datasets Former-commit-id: 202daf8987ccb7523be03ca535b572b5c9e65994 --- README.md | 94 +++++++++++++++++++------------ README_zh.md | 94 +++++++++++++++++++------------ data/README.md | 2 + src/llmtuner/dsets/loader.py | 74 +++++++++++++++++++----- src/llmtuner/dsets/preprocess.py | 23 ++++---- src/llmtuner/hparams/data_args.py | 2 + 6 files changed, 192 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 7eb2d05f..2a87b659 100644 --- a/README.md +++ b/README.md @@ -86,39 +86,61 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o ## Provided Datasets -- For pre-training: - - [Wiki Demo (en)](data/wiki_demo.txt) - - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) - - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) - - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) -- For supervised fine-tuning: - - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - - [Self-cognition (zh)](data/self_cognition.json) - - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) - - [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) - - [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) - - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) - - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) - - [LIMA (en)](https://huggingface.co/datasets/GAIR/lima) - - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - - [UltraChat (en)](https://github.com/thunlp/UltraChat) - - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) -- For reward modeling or DPO training: - - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) +
<details><summary>Pre-training datasets</summary>
+
+- [Wiki Demo (en)](data/wiki_demo.txt)
+- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
+- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
+- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
+- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
+- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+
+</details>
+
+<details><summary>Supervised fine-tuning datasets</summary>
+
+- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Self-cognition (zh)](data/self_cognition.json)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
+- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [UltraChat (en)](https://github.com/thunlp/UltraChat)
+- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
+- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
+- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
+- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
+- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
+- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
+- [LMSYS Chat (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+
+</details>
+
+<details><summary>Preference datasets</summary>
+
+- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+
+</details>
Please refer to [data/README.md](data/README.md) for details. @@ -135,8 +157,8 @@ huggingface-cli login - 🤗Transformers, Datasets, Accelerate, PEFT and TRL - sentencepiece, protobuf and tiktoken - fire, jieba, rouge-chinese and nltk (used at evaluation and predict) -- gradio and matplotlib (used in web_demo.py) -- uvicorn, fastapi and sse-starlette (used in api_demo.py) +- gradio and matplotlib (used in web UI) +- uvicorn, fastapi and sse-starlette (used in API) And **powerful GPUs**! @@ -144,7 +166,7 @@ And **powerful GPUs**! ### Data Preparation (optional) -Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset. +Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset. > [!NOTE] > Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`. diff --git a/README_zh.md b/README_zh.md index ed8be250..2e40c63f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -86,41 +86,63 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 ## 数据集 -- 用于预训练: - - [Wiki Demo (en)](data/wiki_demo.txt) - - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) - - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) - - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) -- 用于指令监督微调: - - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - - [Self-cognition (zh)](data/self_cognition.json) - - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) - - [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) - - [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) - - [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) - - [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) - - [LIMA (en)](https://huggingface.co/datasets/GAIR/lima) - - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - - [UltraChat (en)](https://github.com/thunlp/UltraChat) - - [WebNovel 
(zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) -- 用于训练奖励模型或 DPO 训练: - - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) +
<details><summary>预训练数据集</summary>

-使用方法请参考 [data/README.md](data/README_zh.md) 文件。

+- [Wiki Demo (en)](data/wiki_demo.txt)
+- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
+- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
+- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
+- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
+- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+
+</details>
+
+<details><summary>指令微调数据集</summary>
+
+- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Self-cognition (zh)](data/self_cognition.json)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
+- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [UltraChat (en)](https://github.com/thunlp/UltraChat)
+- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
+- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
+- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
+- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
+- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
+- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
+- [LMSYS Chat (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
+
+</details>
+
+<details><summary>偏好数据集</summary>
+
+- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+
+</details>
+ +使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。 部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。 @@ -144,10 +166,10 @@ huggingface-cli login ### 数据准备(可跳过) -关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。 +关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。 > [!NOTE] -> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`。 +> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。 ### 环境搭建(可跳过) diff --git a/data/README.md b/data/README.md index 3be493b2..8a11561e 100644 --- a/data/README.md +++ b/data/README.md @@ -6,7 +6,9 @@ If you are using a custom dataset, please provide your dataset definition in the "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)", "file_name": "the name of the dataset file in the this directory. (required if above are not specified)", "file_sha1": "the SHA-1 hash value of the dataset file. (optional)", + "subset": "", "ranking": "whether the examples contains ranked responses or not. (default: false)", + "formatting": "", "columns": { "prompt": "the name of the column in the datasets containing the prompts. (default: instruction)", "query": "the name of the column in the datasets containing the queries. (default: input)", diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py index fe88ce50..46db294a 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/dsets/loader.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union from datasets import concatenate_datasets, interleave_datasets, load_dataset @@ -26,22 +26,23 @@ def get_dataset( if dataset_attr.load_from == "hf_hub": data_path = dataset_attr.dataset_name + data_name = dataset_attr.subset data_files = None elif dataset_attr.load_from == "script": data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_name = dataset_attr.subset data_files = None elif dataset_attr.load_from == "file": - data_path = None + data_path, data_name = None, None data_files: List[str] = [] - - if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory + if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) if data_path is None: data_path = EXT2TYPE.get(file_name.split(".")[-1], None) else: - assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match." - elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file + assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical." 
+ elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None) else: @@ -53,7 +54,8 @@ def get_dataset( raise NotImplementedError dataset = load_dataset( - data_path, + path=data_path, + name=data_name, data_files=data_files, split=data_args.split, cache_dir=model_args.cache_dir, @@ -61,15 +63,59 @@ def get_dataset( use_auth_token=True if model_args.use_auth_token else None ) - if max_samples is not None: - max_samples_temp = min(len(dataset), max_samples) - dataset = dataset.select(range(max_samples_temp)) + if max_samples is not None: # truncate dataset + dataset = dataset.select(range(min(len(dataset), max_samples))) - # TODO: adapt to the sharegpt format + def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + # convert dataset from sharegpt format to alpaca format + outputs = {"prompt": [], "query": [], "response": [], "history": []} + for msg_list in examples[dataset_attr.prompt]: + msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2 + if len(msg_list) == 0: + continue - for column_name in ["prompt", "query", "response", "history"]: # align datasets - if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: - dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) + msg_pairs = [] + user_role, assistant_role = None, None + for idx in range(0, len(msg_list), 2): + if user_role is None and assistant_role is None: + user_role = msg_list[idx][dataset_attr.query] + assistant_role = msg_list[idx + 1][dataset_attr.query] + else: + if ( + msg_list[idx][dataset_attr.query] != user_role + or msg_list[idx+1][dataset_attr.query] != assistant_role + ): + raise ValueError("Only accepts conversation in u/a/u/a/u/a order.") + msg_pairs.append((msg_list[idx][dataset_attr.response], msg_list[idx + 1][dataset_attr.response])) + + if len(msg_pairs) != 0: + outputs["prompt"].append(msg_pairs[-1][0]) + outputs["query"].append("") + outputs["response"].append(msg_pairs[-1][1]) + outputs["history"].append(msg_pairs[:-1]) + + return outputs + + if dataset_attr.formatting == "sharegpt": # convert format + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache), + desc="Converting format of dataset" + ) + + dataset = dataset.map( + convert_format, + batched=True, + remove_columns=column_names, + **kwargs + ) + else: + for column_name in ["prompt", "query", "response", "history"]: # align dataset + if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: + dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) if dataset_attr.system_prompt: # add system prompt system_prompt = dataset_attr.system_prompt diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 18f01db1..0484b78e 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -39,7 +39,7 @@ def preprocess_dataset( system = examples["system"][i] if "system" in examples else None yield query, response, history, system - def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, 
List[List[int]]]: # build grouped texts with format `X1 X2 X3 ...` if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) kwargs = dict(allowed_special="all") @@ -62,7 +62,7 @@ def preprocess_dataset( } return result - def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: # build inputs with format ` X Y ` and labels with format ` ... Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} @@ -108,7 +108,7 @@ def preprocess_dataset( return model_inputs - def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} @@ -145,7 +145,7 @@ def preprocess_dataset( return model_inputs - def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: # build inputs with format ` X` and labels with format `Y ` model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} @@ -169,10 +169,10 @@ def preprocess_dataset( return model_inputs - def preprocess_pairwise_dataset(examples): + def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: # build input pairs with format ` X`, `Y1 ` and `Y2 ` model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} - for query, response, history, system in construct_example(examples): + for query, response, history, system in construct_example(examples): if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1): continue @@ -197,9 +197,10 @@ def preprocess_dataset( model_inputs["prompt_ids"].append(prompt_ids) model_inputs["chosen_ids"].append(chosen_ids) model_inputs["rejected_ids"].append(rejected_ids) + return model_inputs - def print_supervised_dataset_example(example): + def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None: print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("label_ids:\n{}".format(example["labels"])) @@ -207,7 +208,7 @@ def preprocess_dataset( tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) )) - def print_pairwise_dataset_example(example): + def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None: print("prompt_ids:\n{}".format(example["prompt_ids"])) print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) print("chosen_ids:\n{}".format(example["chosen_ids"])) @@ -215,7 +216,7 @@ def preprocess_dataset( print("rejected_ids:\n{}".format(example["rejected_ids"])) print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) - def print_unsupervised_dataset_example(example): + def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None: print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], 
skip_special_tokens=False))) @@ -242,13 +243,13 @@ def preprocess_dataset( if not data_args.streaming: kwargs = dict( num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, + load_from_cache_file=(not data_args.overwrite_cache), desc="Running tokenizer on dataset" ) dataset = dataset.map( preprocess_func, - batched=True, + batched=True, remove_columns=column_names, **kwargs ) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 49b86345..2f4cda38 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -11,6 +11,7 @@ class DatasetAttr: dataset_name: Optional[str] = None dataset_sha1: Optional[str] = None system_prompt: Optional[str] = None + subset: Optional[str] = None ranking: Optional[bool] = False formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" @@ -155,6 +156,7 @@ class DataArguments: dataset_attr.response = dataset_info[name]["columns"].get("response", None) dataset_attr.history = dataset_info[name]["columns"].get("history", None) + dataset_attr.subset = dataset_info[name].get("subset", None) dataset_attr.ranking = dataset_info[name].get("ranking", False) dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca") dataset_attr.system_prompt = prompt_list[i]
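
To illustrate how the ShareGPT support introduced by this patch is meant to be used, here is a small, self-contained sketch (not part of the committed code). It pairs a hypothetical `data/dataset_info.json` entry exercising the new `formatting` key with a standalone re-implementation of the conversion performed by `convert_format` in `src/llmtuner/dsets/loader.py`. The dataset name, file name, and the `from`/`value` keys are assumptions for illustration only; in practice the role and content keys are whatever your data uses and are declared through the `columns` mapping.

```python
# Illustrative sketch only -- not part of the patch. Dataset/file names and the
# "from"/"value" keys are hypothetical; real ones go in data/dataset_info.json.
from typing import Any, Dict, List

# Hypothetical dataset_info.json entry for a local ShareGPT-style file, using the
# new "formatting" key. For a Hugging Face Hub dataset, the new "subset" key would
# additionally be forwarded to datasets.load_dataset(name=...).
EXAMPLE_DATASET_INFO = {
    "my_sharegpt_data": {
        "file_name": "my_sharegpt_data.json",  # hypothetical local file
        "formatting": "sharegpt",              # new key: "alpaca" (default) or "sharegpt"
        "columns": {
            "prompt": "conversations",         # column holding the message list
            "query": "from",                   # key holding the speaker role
            "response": "value"                # key holding the message text
        }
    }
}


def sharegpt_to_alpaca(conversations: List[Dict[str, str]],
                       role_key: str = "from",
                       content_key: str = "value") -> Dict[str, Any]:
    """Convert one ShareGPT conversation into the alpaca-style fields
    (prompt / query / response / history), mirroring convert_format in
    src/llmtuner/dsets/loader.py."""
    msgs = conversations[:len(conversations) // 2 * 2]  # keep complete user/assistant pairs
    pairs = []
    user_role, assistant_role = None, None
    for i in range(0, len(msgs), 2):
        if user_role is None and assistant_role is None:
            user_role, assistant_role = msgs[i][role_key], msgs[i + 1][role_key]
        elif msgs[i][role_key] != user_role or msgs[i + 1][role_key] != assistant_role:
            raise ValueError("Only accepts conversations in u/a/u/a order.")
        pairs.append((msgs[i][content_key], msgs[i + 1][content_key]))
    if not pairs:
        return {}
    return {
        "prompt": pairs[-1][0],    # last user turn becomes the prompt
        "query": "",
        "response": pairs[-1][1],  # last assistant turn becomes the response
        "history": pairs[:-1],     # earlier turns become the multi-turn history
    }


if __name__ == "__main__":
    conversation = [
        {"from": "human", "value": "Hello!"},
        {"from": "gpt", "value": "Hi, how can I help?"},
        {"from": "human", "value": "Tell me a joke."},
        {"from": "gpt", "value": "Why did the tensor cross the road?"},
    ]
    print(sharegpt_to_alpaca(conversation))
```

Running the sketch prints an alpaca-style record in which the last human/gpt exchange becomes `prompt`/`response` and every earlier exchange is carried in `history`, which is the shape `get_dataset` now hands to `preprocess_dataset` for ShareGPT-formatted corpora.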