From 54b8ce7b63580d4dc0851f13719e191e351546bc Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 28 May 2023 18:09:04 +0800 Subject: [PATCH] Initial commit Former-commit-id: 769c6ab56be0c9d26e9289f61ac54a4068d935c1 --- .gitattributes | 2 + LICENSE | 201 ++++++++ README.md | 29 ++ data/README.md | 53 ++ data/alpaca_data_en_52k.json.REMOVED.git-id | 1 + data/alpaca_data_zh_51k.json.REMOVED.git-id | 1 + data/alpaca_gpt4_data_en.json.REMOVED.git-id | 1 + data/alpaca_gpt4_data_zh.json.REMOVED.git-id | 1 + ...omparison_gpt4_data_en.json.REMOVED.git-id | 1 + ...omparison_gpt4_data_zh.json.REMOVED.git-id | 1 + data/dataset_info.json | 97 ++++ data/example_dataset/example_dataset.py | 46 ++ data/example_dataset/examples.json | 20 + data/hh_rlhf_en/hh_rlhf_en.py | 97 ++++ data/ultra_chat/ultra_chat.py | 76 +++ src/__init__.py | 0 src/cli_demo.py | 66 +++ src/export_model.py | 23 + src/train_ppo.py | 80 +++ src/train_rm.py | 72 +++ src/train_sft.py | 95 ++++ src/utils/__init__.py | 15 + src/utils/common.py | 459 ++++++++++++++++++ src/utils/config.py | 212 ++++++++ src/utils/data_collator.py | 67 +++ src/utils/other.py | 205 ++++++++ src/utils/pairwise.py | 51 ++ src/utils/peft_trainer.py | 78 +++ src/utils/ppo.py | 241 +++++++++ src/utils/seq2seq.py | 96 ++++ src/web_demo.py | 129 +++++ 31 files changed, 2516 insertions(+) create mode 100644 .gitattributes create mode 100644 LICENSE create mode 100644 README.md create mode 100644 data/README.md create mode 100644 data/alpaca_data_en_52k.json.REMOVED.git-id create mode 100644 data/alpaca_data_zh_51k.json.REMOVED.git-id create mode 100644 data/alpaca_gpt4_data_en.json.REMOVED.git-id create mode 100644 data/alpaca_gpt4_data_zh.json.REMOVED.git-id create mode 100644 data/comparison_gpt4_data_en.json.REMOVED.git-id create mode 100644 data/comparison_gpt4_data_zh.json.REMOVED.git-id create mode 100644 data/dataset_info.json create mode 100644 data/example_dataset/example_dataset.py create mode 100644 data/example_dataset/examples.json create mode 100644 data/hh_rlhf_en/hh_rlhf_en.py create mode 100644 data/ultra_chat/ultra_chat.py create mode 100644 src/__init__.py create mode 100644 src/cli_demo.py create mode 100644 src/export_model.py create mode 100644 src/train_ppo.py create mode 100644 src/train_rm.py create mode 100644 src/train_sft.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/common.py create mode 100644 src/utils/config.py create mode 100644 src/utils/data_collator.py create mode 100644 src/utils/other.py create mode 100644 src/utils/pairwise.py create mode 100644 src/utils/peft_trainer.py create mode 100644 src/utils/ppo.py create mode 100644 src/utils/seq2seq.py create mode 100644 src/web_demo.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..dfe07704 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..b09cd785 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..0dd4e45c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,29 @@
+# LLaMA Efficient Tuning
+
+1. Download the weights of the LLaMA models.
+2. Convert them to HF format using this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
+
+```bash
+python convert_llama_weights_to_hf.py \
+    --input_dir path_to_llama_weights --model_size 7B --output_dir llama_7b
+```
+
+3. Fine-tune the LLaMA models.
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
+    --model_name_or_path llama_7b \
+    --do_train \
+    --dataset alpaca_gpt4_zh \
+    --finetuning_type lora \
+    --output_dir path_to_sft_checkpoint \
+    --overwrite_cache \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 2 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --fp16
+```
diff --git a/data/README.md b/data/README.md
new file mode 100644
index 00000000..6a30cb2b
--- /dev/null
+++ b/data/README.md
@@ -0,0 +1,53 @@
+Data format in `dataset_info.json`:
+```json
+"dataset_name": {
+    "hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, the following three arguments are ignored)",
+    "script_url": "the name of the directory containing a dataset loading script. (if specified, the following two arguments are ignored)",
+    "file_name": "the name of the dataset file in this directory. (required if the above are not specified)",
+    "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
+    "columns": {
+        "prompt": "the name of the column in the dataset containing the prompts. (default: instruction)",
+        "query": "the name of the column in the dataset containing the queries. (default: input)",
+        "response": "the name of the column in the dataset containing the responses. (default: output)",
+        "history": "the name of the column in the dataset containing the chat history. (default: None)"
+    }
+}
+```
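+
+For example, a local JSON file could be registered with an entry like the following (the dataset name, file name, and hash below are illustrative, not shipped with the repository):
+```json
+"my_dataset": {
+    "file_name": "my_dataset.json",
+    "file_sha1": "0123456789abcdef0123456789abcdef01234567",
+    "columns": {
+        "prompt": "instruction",
+        "query": "input",
+        "response": "output",
+        "history": "history"
+    }
+}
+```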
+
+A brief overview of some of the built-in datasets:
+
+| Dataset | Size | Description |
+| --- | --- | --- |
+| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | The Alpaca dataset released by Stanford, used to train early LLaMA-based models such as Alpaca |
+| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | The Alpaca dataset translated into Chinese with ChatGPT |
+| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | Self-instruct data generated with GPT-4 |
+| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | About 2 million Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
+| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | About 1 million Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
+| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | About 500 thousand Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
+| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | About 400 thousand personalized role-play dialogues generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, including character profiles |
+| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | About 250 thousand Chinese math problems generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, including solution steps |
+| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | About 800 thousand multi-turn dialogues between users and an assistant, generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
+| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | Multilingual data covering Japanese, Simplified and Traditional Chinese, English, and more, originally built to train the Guanaco model |
+| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | The Chinese dataset of the Firefly chat model, covering multiple NLP tasks |
+| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | An English code-generation dataset |
+| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | A collection of instruction datasets for fine-tuning |
+| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | A Chinese question-answering dataset collected from Baidu Zhidao |
+| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | A large-scale multi-turn dialogue dataset released by THUNLP |
+
+Note: the BELLE datasets are generated by ChatGPT, so their accuracy is not guaranteed; the same caveat applies to every self-instruct dataset produced by GPT-like models.
diff --git a/data/alpaca_data_en_52k.json.REMOVED.git-id b/data/alpaca_data_en_52k.json.REMOVED.git-id
new file mode 100644
index 00000000..5568c425
--- /dev/null
+++ b/data/alpaca_data_en_52k.json.REMOVED.git-id
@@ -0,0 +1 @@
+3779ddbc040543ab1834ef216c983d6fcc06cc9a
\ No newline at end of file
diff --git a/data/alpaca_data_zh_51k.json.REMOVED.git-id b/data/alpaca_data_zh_51k.json.REMOVED.git-id
new file mode 100644
index 00000000..955f6959
--- /dev/null
+++ b/data/alpaca_data_zh_51k.json.REMOVED.git-id
@@ -0,0 +1 @@
+fc9a6a3458caca2af8dafc6181773fe10c6d8657
\ No newline at end of file
diff --git a/data/alpaca_gpt4_data_en.json.REMOVED.git-id b/data/alpaca_gpt4_data_en.json.REMOVED.git-id
new file mode 100644
index 00000000..15985776
--- /dev/null
+++ b/data/alpaca_gpt4_data_en.json.REMOVED.git-id
@@ -0,0 +1 @@
+25508714b7879a1e5a6764ba7f979a980f549f1a
\ No newline at end of file
diff --git a/data/alpaca_gpt4_data_zh.json.REMOVED.git-id
b/data/alpaca_gpt4_data_zh.json.REMOVED.git-id new file mode 100644 index 00000000..c86d1aea --- /dev/null +++ b/data/alpaca_gpt4_data_zh.json.REMOVED.git-id @@ -0,0 +1 @@ +7cb6a7d11455bddc3d495750a2392683d775b184 \ No newline at end of file diff --git a/data/comparison_gpt4_data_en.json.REMOVED.git-id b/data/comparison_gpt4_data_en.json.REMOVED.git-id new file mode 100644 index 00000000..d7c6f987 --- /dev/null +++ b/data/comparison_gpt4_data_en.json.REMOVED.git-id @@ -0,0 +1 @@ +f437d58b7791609ee91f064551c5c5734a0fd97a \ No newline at end of file diff --git a/data/comparison_gpt4_data_zh.json.REMOVED.git-id b/data/comparison_gpt4_data_zh.json.REMOVED.git-id new file mode 100644 index 00000000..6007329e --- /dev/null +++ b/data/comparison_gpt4_data_zh.json.REMOVED.git-id @@ -0,0 +1 @@ +0e346cf70e633456c7e83f68765361016005447a \ No newline at end of file diff --git a/data/dataset_info.json b/data/dataset_info.json new file mode 100644 index 00000000..029aa6e7 --- /dev/null +++ b/data/dataset_info.json @@ -0,0 +1,97 @@ +{ + "alpaca_en": { + "hf_hub_url": "tatsu-lab/alpaca" + }, + "alpaca_zh": { + "file_name": "alpaca_data_zh_51k.json", + "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311" + }, + "alpaca_gpt4_en": { + "file_name": "alpaca_gpt4_data_en.json", + "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a" + }, + "alpaca_gpt4_zh": { + "file_name": "alpaca_gpt4_data_zh.json", + "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845" + }, + "belle_0.5m": { + "hf_hub_url": "BelleGroup/train_0.5M_CN" + }, + "belle_1m": { + "hf_hub_url": "BelleGroup/train_1M_CN" + }, + "belle_2m": { + "hf_hub_url": "BelleGroup/train_2M_CN" + }, + "belle_dialog": { + "hf_hub_url": "BelleGroup/generated_chat_0.4M" + }, + "belle_math": { + "hf_hub_url": "BelleGroup/school_math_0.25M" + }, + "belle_multiturn": { + "hf_hub_url": "BelleGroup/multiturn_chat_0.8M" + }, + "guanaco": { + "hf_hub_url": "JosephusCheung/GuanacoDataset" + }, + "firefly": { + "hf_hub_url": "YeungNLP/firefly-train-1.1M", + "columns": { + "prompt": "input", + "query": "", + "response": "target", + "history": "" + } + }, + "codealpaca": { + "hf_hub_url": "sahil2801/CodeAlpaca-20k" + }, + "alpaca_cot": { + "hf_hub_url": "QingyiSi/Alpaca-CoT" + }, + "webqa": { + "hf_hub_url": "suolyer/webqa", + "columns": { + "prompt": "input", + "query": "", + "response": "output", + "history": "" + } + }, + "ultra_chat": { + "script_url": "ultra_chat", + "columns": { + "prompt": "instruction", + "query": "", + "response": "output", + "history": "history" + } + }, + "example": { + "script_url": "example_dataset", + "columns": { + "prompt": "instruction", + "query": "input", + "response": "output", + "history": "history" + } + }, + "comparison_gpt4_en": { + "file_name": "comparison_gpt4_data_en.json", + "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f" + }, + "comparison_gpt4_zh": { + "file_name": "comparison_gpt4_data_zh.json", + "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0" + }, + "hh_rlhf_en": { + "script_url": "hh_rlhf_en", + "columns": { + "prompt": "instruction", + "query": "", + "response": "output", + "history": "history" + } + } +} diff --git a/data/example_dataset/example_dataset.py b/data/example_dataset/example_dataset.py new file mode 100644 index 00000000..db3e9ffb --- /dev/null +++ b/data/example_dataset/example_dataset.py @@ -0,0 +1,46 @@ +import json +import datasets +from typing import Any, Dict, List + + +_DESCRIPTION = "An example of dataset for LLaMA." 
+_CITATION = ""
+_HOMEPAGE = ""
+_LICENSE = ""
+_URL = "examples.json"
+
+
+class ExampleDataset(datasets.GeneratorBasedBuilder):
+
+    VERSION = datasets.Version("0.0.0")
+
+    def _info(self) -> datasets.DatasetInfo:
+        features = datasets.Features({
+            "instruction": datasets.Value("string"),
+            "input": datasets.Value("string"),
+            "output": datasets.Value("string"),
+            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
+        })
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=features,
+            homepage=_HOMEPAGE,
+            license=_LICENSE,
+            citation=_CITATION
+        )
+
+    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
+        file_path = dl_manager.download(_URL)
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN,
+                gen_kwargs={
+                    "filepath": file_path
+                }
+            )
+        ]
+
+    def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
+        with open(filepath, "r", encoding="utf-8") as f: # close the file after reading
+            example_dataset = json.load(f)
+        for key, example in enumerate(example_dataset):
+            yield key, example
diff --git a/data/example_dataset/examples.json b/data/example_dataset/examples.json
new file mode 100644
index 00000000..3b4205e8
--- /dev/null
+++ b/data/example_dataset/examples.json
@@ -0,0 +1,20 @@
+[
+    {
+        "instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
+        "input": "",
+        "output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
+        "history": [
+            ["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
+            ["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
+        ]
+    },
+    {
+        "instruction": "好的,谢谢你!",
+        "input": "",
+        "output": "不客气,有其他需要帮忙的地方可以继续问我。",
+        "history": [
+            ["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
+            ["我在纽约。", "纽约今天晴间多云,气温最高约26摄氏度,最低约18摄氏度,记得注意保暖喔。"]
+        ]
+    }
+]
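The loading script above follows the standard `datasets.GeneratorBasedBuilder` pattern, so the example dataset can be loaded directly from its directory. A minimal sketch, assuming it is run from the repository root with `datasets>=2.10.0` installed:

```python
from datasets import load_dataset

# Resolves data/example_dataset/example_dataset.py and reads examples.json.
dataset = load_dataset("data/example_dataset")
print(dataset["train"][0]["instruction"])
```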
diff --git a/data/hh_rlhf_en/hh_rlhf_en.py b/data/hh_rlhf_en/hh_rlhf_en.py
new file mode 100644
index 00000000..8d51e4c4
--- /dev/null
+++ b/data/hh_rlhf_en/hh_rlhf_en.py
@@ -0,0 +1,97 @@
+import json
+import datasets
+from typing import Any, Dict, List
+
+
+_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
+_CITATION = ""
+_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
+_LICENSE = "mit"
+_URL = "https://huggingface.co/datasets/Anthropic/hh-rlhf/resolve/main/"
+_URLS = {
+    "train": [
+        _URL + "harmless-base/train.jsonl.gz",
+        _URL + "helpful-base/train.jsonl.gz",
+        _URL + "helpful-online/train.jsonl.gz",
+        _URL + "helpful-rejection-sampled/train.jsonl.gz"
+    ],
+    "test": [
+        _URL + "harmless-base/test.jsonl.gz",
+        _URL + "helpful-base/test.jsonl.gz",
+        _URL + "helpful-online/test.jsonl.gz",
+        _URL + "helpful-rejection-sampled/test.jsonl.gz"
+    ]
+}
+
+
+class HhRlhfEn(datasets.GeneratorBasedBuilder):
+
+    VERSION = datasets.Version("0.0.0")
+
+    def _info(self) -> datasets.DatasetInfo:
+        features = datasets.Features({
+            "instruction": datasets.Value("string"),
+            "output": datasets.Sequence(datasets.Value("string")),
+            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
+        })
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=features,
+            homepage=_HOMEPAGE,
+            license=_LICENSE,
+            citation=_CITATION
+        )
+
+    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
+        file_path = dl_manager.download_and_extract(_URLS)
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN,
+                gen_kwargs={
+                    "filepaths": file_path["train"]
+                }
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.TEST,
+                gen_kwargs={
+                    "filepaths": file_path["test"]
+                }
+            )
+        ]
+
+    def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn preference data
+        key = 0
+        for filepath in filepaths:
+            with open(filepath, "r", encoding="utf-8") as f:
+                for row in f:
+                    data = json.loads(row)
+                    chosen = data["chosen"]
+                    rejected = data["rejected"]
+
+                    assist_idx = rejected.rfind("\n\nAssistant: ")
+                    r_reject = rejected[assist_idx+13:].strip() # 13 == len("\n\nAssistant: ")
+                    assist_idx = chosen.rfind("\n\nAssistant: ")
+                    r_accept = chosen[assist_idx+13:].strip()
+
+                    human_idx = chosen.rfind("\n\nHuman: ")
+                    query = chosen[human_idx+9:assist_idx].strip() # 9 == len("\n\nHuman: ")
+                    prompt = chosen[:human_idx]
+                    history = []
+
+                    while prompt.rfind("\n\nAssistant: ") != -1:
+                        assist_idx = prompt.rfind("\n\nAssistant: ")
+                        human_idx = prompt.rfind("\n\nHuman: ")
+                        if human_idx != -1:
+                            old_query = prompt[human_idx+9:assist_idx].strip()
+                            old_resp = prompt[assist_idx+13:].strip()
+                            history.insert(0, (old_query, old_resp))
+                        else:
+                            break
+                        prompt = prompt[:human_idx]
+
+                    yield key, {
+                        "instruction": query,
+                        "output": [r_accept, r_reject],
+                        "history": history
+                    }
+                    key += 1
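The offset arithmetic in `_generate_examples` walks an Anthropic transcript backwards. A standalone toy illustration of what it extracts (the transcript below is made up):

```python
chosen = "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: How are you?\n\nAssistant: Great."
# The last "\n\nAssistant: " marks the preferred response, the last "\n\nHuman: "
# marks the final query, and everything before it becomes the history.
assist_idx = chosen.rfind("\n\nAssistant: ")
human_idx = chosen.rfind("\n\nHuman: ")
print(chosen[assist_idx + 13:].strip())          # "Great."
print(chosen[human_idx + 9:assist_idx].strip())  # "How are you?"
print(chosen[:human_idx])                        # earlier turns -> history
```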
diff --git a/data/ultra_chat/ultra_chat.py b/data/ultra_chat/ultra_chat.py
new file mode 100644
index 00000000..dd29311c
--- /dev/null
+++ b/data/ultra_chat/ultra_chat.py
@@ -0,0 +1,76 @@
+import json
+import datasets
+from typing import Any, Dict, List
+
+
+_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
+
+_CITATION = """\
+@misc{UltraChat,
+    author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
+    title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
+    year = {2023},
+    publisher = {GitHub},
+    journal = {GitHub repository},
+    howpublished = {\\url{https://github.com/thunlp/ultrachat}},
+}
+"""
+
+_HOMEPAGE = "https://huggingface.co/datasets/stingning/ultrachat"
+_LICENSE = "cc-by-nc-4.0"
+_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
+
+
+class UltraChat(datasets.GeneratorBasedBuilder):
+
+    VERSION = datasets.Version("0.0.0")
+
+    def _info(self) -> datasets.DatasetInfo:
+        features = datasets.Features({
+            "instruction": datasets.Value("string"),
+            "output": datasets.Value("string"),
+            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
+        })
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=features,
+            homepage=_HOMEPAGE,
+            license=_LICENSE,
+            citation=_CITATION
+        )
+
+    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
+        file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN,
+                gen_kwargs={
+                    "filepaths": file_paths
+                }
+            )
+        ]
+
+    def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat data
+        for filepath in filepaths:
+            with open(filepath, "r", encoding="utf-8") as f:
+                for row in f:
+                    try:
+                        data = json.loads(row)
+                    except json.JSONDecodeError: # skip malformed rows
+                        continue
+                    key = data["id"]
+                    content = data["data"]
+                    if len(content) % 2 == 1: # drop a trailing query without a response
+                        content.pop(-1)
+                    if len(content) < 2:
+                        continue
+
+                    query = content[-2]
+                    response = content[-1]
+                    history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
+
+                    yield key, {
+                        "instruction": query,
+                        "output": response,
+                        "history": history
+                    }
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/cli_demo.py b/src/cli_demo.py
new file mode 100644
index 00000000..2ded2682
--- /dev/null
+++ b/src/cli_demo.py
@@ -0,0 +1,66 @@
+# coding=utf-8
+# Implements stream chat in command line for LLaMA fine-tuned with PEFT.
+# Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
+
+
+import torch
+from utils import ModelArguments, auto_configure_device_map, load_pretrained
+from transformers import HfArgumentParser
+
+
+def main():
+
+    parser = HfArgumentParser(ModelArguments)
+    model_args, = parser.parse_args_into_dataclasses()
+    model, tokenizer = load_pretrained(model_args)
+    if torch.cuda.device_count() > 1:
+        from accelerate import dispatch_model
+        device_map = auto_configure_device_map(torch.cuda.device_count())
+        model = dispatch_model(model, device_map)
+    else:
+        model = model.cuda()
+    model.eval()
+
+    def predict(query, history: list):
+        inputs = tokenizer([query], return_tensors="pt")
+        inputs = inputs.to(model.device)
+        gen_kwargs = {
+            "do_sample": True,
+            "top_p": 0.9,
+            "top_k": 40,
+            "temperature": 0.7,
+            "num_beams": 1,
+            "max_new_tokens": 256,
+            "repetition_penalty": 1.5
+        }
+        with torch.no_grad():
+            generation_output = model.generate(**inputs, **gen_kwargs)
+        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+        response = tokenizer.decode(outputs, skip_special_tokens=True)
+        history = history + [(query, response)]
+        return response, history
+
+    history = []
+    print("Welcome to the LLaMA-7B model. Type a message to chat, 'clear' to clear the dialogue history, or 'stop' to exit.")
+    while True:
+        try:
+            query = input("\nInput: ")
+        except UnicodeDecodeError:
+            print("Detected a decoding error in the input. Please set the terminal encoding to UTF-8.")
+            continue
+        except Exception:
+            raise
+
+        if query.strip() == "stop":
+            break
+
+        if query.strip() == "clear":
+            history = []
+            continue
+
+        response, history = predict(query, history)
+        print("LLaMA-7B:", response)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/export_model.py b/src/export_model.py
new file mode 100644
index 00000000..b62f10f1
--- /dev/null
+++ b/src/export_model.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+# Exports the fine-tuned LLaMA model.
+# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
+
+
+from transformers import HfArgumentParser, TrainingArguments
+from utils import ModelArguments, load_pretrained
+
+
+def main():
+
+    parser = HfArgumentParser((ModelArguments, TrainingArguments))
+    model_args, training_args = parser.parse_args_into_dataclasses()
+
+    model, tokenizer = load_pretrained(model_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
+    tokenizer.save_pretrained(training_args.output_dir)
+
+    print("Model and tokenizer have been saved at:", training_args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/train_ppo.py b/src/train_ppo.py
new file mode 100644
index 00000000..41f89a57
--- /dev/null
+++ b/src/train_ppo.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Implements parameter-efficient PPO training of fine-tuned LLaMA.
+# This code is inspired by: +# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py + +import math + +from torch.optim import AdamW + +from transformers.optimization import get_scheduler +from trl import PPOConfig + +from utils import ( + prepare_args, + prepare_data, + load_pretrained, + preprocess_data, + DataCollatorForLLaMA, + PPOTrainerForLLaMA, + plot_loss +) + + +def main(): + + # prepare pretrained model and dataset + model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo") + dataset = prepare_data(model_args, data_args) + model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo") + dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo") + data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model) + + ppo_config = PPOConfig( + model_name=model_args.model_name_or_path, + learning_rate=training_args.learning_rate, + mini_batch_size=training_args.per_device_train_batch_size, + batch_size=training_args.per_device_train_batch_size, + gradient_accumulation_steps=training_args.gradient_accumulation_steps, + ppo_epochs=1, + max_grad_norm=training_args.max_grad_norm + ) + + optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate) + total_train_batch_size = \ + training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size + lr_scheduler = get_scheduler( + training_args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=training_args.warmup_steps, + num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)) + ) + + # Initialize our Trainer + ppo_trainer = PPOTrainerForLLaMA( + training_args=training_args, + finetuning_args=finetuning_args, + config=ppo_config, + model=model, + ref_model=None, + tokenizer=tokenizer, + dataset=dataset, + data_collator=data_collator, + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + + ppo_trainer.ppo_train(max_target_length=data_args.max_target_length) + ppo_trainer.save_model() + ppo_trainer.save_state() # must be after save_model + if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args, keys=["loss", "reward"]) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/src/train_rm.py b/src/train_rm.py new file mode 100644 index 00000000..dd544f3a --- /dev/null +++ b/src/train_rm.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Implements parameter-efficient training of a reward model based on LLaMA. 
+# This code is inspired by:
+# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
+# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+
+
+from utils import (
+    prepare_args,
+    prepare_data,
+    load_pretrained,
+    preprocess_data,
+    PairwiseDataCollatorForLLaMA,
+    PairwiseTrainerForLLaMA,
+    plot_loss
+)
+
+
+def main():
+
+    # prepare pretrained model and dataset
+    model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
+    dataset = prepare_data(model_args, data_args)
+    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
+    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
+    data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)
+
+    training_args.remove_unused_columns = False # Important for pairwise dataset
+
+    # Split the dataset
+    if training_args.do_train:
+        if data_args.dev_ratio > 1e-6:
+            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
+            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+        else:
+            trainer_kwargs = {"train_dataset": dataset}
+    else: # do_eval or do_predict
+        trainer_kwargs = {"eval_dataset": dataset}
+
+    # Initialize our Trainer
+    trainer = PairwiseTrainerForLLaMA(
+        finetuning_args=finetuning_args,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        **trainer_kwargs
+    )
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        trainer.save_model()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            plot_loss(training_args, keys=["loss", "eval_loss"])
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/train_sft.py b/src/train_sft.py
new file mode 100644
index 00000000..d34a8e4e
--- /dev/null
+++ b/src/train_sft.py
@@ -0,0 +1,95 @@
+# coding=utf-8
+# Implements several parameter-efficient supervised fine-tuning methods for LLaMA.
+# This code is inspired by +# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py + + +from utils import ( + load_pretrained, + prepare_args, + prepare_data, + preprocess_data, + DataCollatorForLLaMA, + Seq2SeqTrainerForLLaMA, + ComputeMetrics, + get_logits_processor, + plot_loss +) + + +def main(): + + # Prepare pretrained model and dataset + model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft") + dataset = prepare_data(model_args, data_args) + model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft") + dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft") + data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss) + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = training_args.generation_max_length if \ + training_args.generation_max_length is not None else data_args.max_target_length + training_args.generation_num_beams = data_args.num_beams if \ + data_args.num_beams is not None else training_args.generation_num_beams + + # Split the dataset + if training_args.do_train: + if data_args.dev_ratio > 1e-6: + dataset = dataset.train_test_split(test_size=data_args.dev_ratio) + trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} + else: + trainer_kwargs = {"train_dataset": dataset} + else: # do_eval or do_predict + trainer_kwargs = {"eval_dataset": dataset} + + # Initialize our Trainer + trainer = Seq2SeqTrainerForLLaMA( + finetuning_args=finetuning_args, + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, + **trainer_kwargs + ) + + # Keyword arguments for `model.generate` + gen_kwargs = { + "do_sample": True, + "top_p": 0.7, + "max_length": data_args.max_source_length + data_args.max_target_length + 1, + "temperature": 0.95, + "logits_processor": get_logits_processor() + } + + # Training + if training_args.do_train: + train_result = trainer.train() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(predict_results, tokenizer) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 00000000..c19e82ad --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,15 @@ +from .common import ( + load_pretrained, + prepare_args, + prepare_data, + preprocess_data +) + +from .data_collator import DataCollatorForLLaMA + +from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA +from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA +from .ppo import PPOTrainerForLLaMA + +from 
.config import ModelArguments +from .other import auto_configure_device_map, get_logits_processor, plot_loss diff --git a/src/utils/common.py b/src/utils/common.py new file mode 100644 index 00000000..db6bfd22 --- /dev/null +++ b/src/utils/common.py @@ -0,0 +1,459 @@ +import os +import sys +import torch +import hashlib +from typing import List, Literal, Optional, Tuple + +import transformers +from transformers import ( + LlamaForCausalLM, + LlamaTokenizer, + HfArgumentParser, + Seq2SeqTrainingArguments +) +from transformers.utils import check_min_version +from transformers.utils.versions import require_version +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer + +import datasets +from datasets import Dataset, concatenate_datasets, load_dataset + +from peft import ( + PeftModel, + TaskType, + LoraConfig, + get_peft_model +) + +from trl import AutoModelForCausalLMWithValueHead + +from .config import ( + ModelArguments, + DataTrainingArguments, + FinetuningArguments +) + +from .other import ( + get_logger, + load_trainable_params, + load_valuehead_params, + print_trainable_params, + prepare_model_for_training, + IGNORE_INDEX, + FINETUNING_ARGS_NAME +) + +check_min_version("4.29.1") +require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0") +require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") +require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") + + +logger = get_logger(__name__) + + +def init_adapter( + model: PreTrainedModel, + model_args: ModelArguments, + finetuning_args: FinetuningArguments, + is_trainable: bool +) -> PreTrainedModel: + r""" + Initializes the adapters. + + Support full-parameter, freeze and LoRA training. + + Note that the trainable parameters must be cast to float32. 
+ """ + + if finetuning_args.finetuning_type == "none" and is_trainable: + raise ValueError("You cannot use finetuning_type=none while training.") + + if finetuning_args.finetuning_type == "full": + logger.info("Fine-tuning method: Full") + model = model.float() + + if finetuning_args.finetuning_type == "freeze": + logger.info("Fine-tuning method: Freeze") + for name, param in model.named_parameters(): + if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers): + param.requires_grad_(False) + else: + param.data = param.data.to(torch.float32) + + if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: + load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods + + if finetuning_args.finetuning_type == "lora": + logger.info("Fine-tuning method: LoRA") + lastest_checkpoint = None + + if model_args.checkpoint_dir is not None: + if is_trainable and finetuning_args.resume_lora_training: # continually train on the lora weights + checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] + else: + checkpoints_to_merge = model_args.checkpoint_dir + + for checkpoint in checkpoints_to_merge: + model = PeftModel.from_pretrained(model, checkpoint) + model = model.merge_and_unload() + + if len(checkpoints_to_merge) > 0: + logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) + + if lastest_checkpoint is not None: # resume lora training + model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=True) + + if is_trainable and lastest_checkpoint is None: # create new lora weights while training + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=finetuning_args.lora_rank, + lora_alpha=finetuning_args.lora_alpha, + lora_dropout=finetuning_args.lora_dropout, + target_modules=finetuning_args.lora_target + ) + model = get_peft_model(model, lora_config) + + return model + + +def load_pretrained( + model_args: ModelArguments, + finetuning_args: Optional[FinetuningArguments] = None, + is_trainable: Optional[bool] = False, + stage: Optional[Literal["sft", "rm", "ppo"]] = "sft" +) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: + r""" + Loads pretrained model and tokenizer. + + Support both training and inference. + """ + + if (not is_trainable) and (model_args.checkpoint_dir is None): + logger.warning("Checkpoint is not found at evaluation, load the original model.") + finetuning_args = FinetuningArguments(finetuning_type="none") + + if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint + for checkpoint_dir in model_args.checkpoint_dir: + if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)): + raise ValueError("The fine-tuning arguments are not found in the provided dictionary.") + logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) + finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)) + if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1: + logger.warning("Only LoRA tuning accepts multiple checkpoints.") + + assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method." 
+def load_pretrained(
+        model_args: ModelArguments,
+        finetuning_args: Optional[FinetuningArguments] = None,
+        is_trainable: Optional[bool] = False,
+        stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    r"""
+    Loads the pretrained model and tokenizer.
+
+    Supports both training and inference.
+    """
+
+    if (not is_trainable) and (model_args.checkpoint_dir is None):
+        logger.warning("No checkpoint found for evaluation; loading the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")
+
+    if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
+        for checkpoint_dir in model_args.checkpoint_dir:
+            if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
+                raise ValueError("The fine-tuning arguments are not found in the provided directory.")
+        logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+        finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
+        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
+            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
+
+    assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with the LoRA method."
+
+    tokenizer = LlamaTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        use_fast=model_args.use_fast_tokenizer,
+        padding_side="left"
+    )
+    tokenizer.pad_token_id = 0 # set as the <unk> token
+
+    # Quantization configurations (using bitsandbytes library).
+    config_kwargs = {}
+    if model_args.quantization_bit is not None:
+        assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
+
+        require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
+        from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
+        cuda = get_cuda_lib_handle()
+        cc = get_compute_capability(cuda)
+        assert is_cublasLt_compatible(cc), "The current GPU(s) are incompatible with quantization."
+
+        config_kwargs["load_in_8bit"] = True
+        config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
+        logger.info("Quantizing model to {} bits.".format(model_args.quantization_bit))
+
+    # Load and prepare pretrained models (without valuehead).
+    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    model = prepare_model_for_training(model) if is_trainable else model
+    model = init_adapter(model, model_args, finetuning_args, is_trainable)
+
+    if not is_trainable:
+        model.requires_grad_(False) # fix all model params
+        model = model.half() # cast all params to float16 for inference
+
+    if stage == "rm" or stage == "ppo": # add value head
+        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+
+    if stage == "ppo": # load reward model
+        assert is_trainable, "PPO stage cannot be performed at evaluation."
+        assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
+        logger.info("Load reward model from {}".format(model_args.reward_model))
+        model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
+        load_valuehead_params(model, model_args.reward_model)
+
+    # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
+    # to meet the compliance requirements of the transformers library
+    if model_args.quantization_bit is not None:
+        model._is_int8_training_enabled = True
+
+    print_trainable_params(model)
+
+    return model, tokenizer
+
+
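+# Example (inference, illustrative): `load_pretrained(ModelArguments(model_name_or_path="llama_7b"))`
+# returns the frozen base model in float16 together with its tokenizer.
+
+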
+def prepare_args(
+        stage: Literal["sft", "rm", "ppo"]
+) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # provide arguments with a json file
+        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
+
+    # Setup logging
+    if training_args.should_log:
+        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+        transformers.utils.logging.set_verbosity_info()
+
+    log_level = training_args.get_process_log_level()
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
+    if stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set to True in the RM and PPO stages.")
+
+    if training_args.do_train and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set to True while training.")
+
+    if training_args.do_predict and (not training_args.predict_with_generate):
+        raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
+
+    if model_args.quantization_bit is not None and (not training_args.do_train):
+        logger.warning("Evaluating the model in 8-bit mode may yield lower scores.")
+
+    if training_args.do_train and (not training_args.fp16):
+        logger.warning("We recommend enabling fp16 mixed-precision training for LLaMA.")
+
+    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
+        logger.warning("`ddp_find_unused_parameters` needs to be set to False in DDP training.")
+        training_args.ddp_find_unused_parameters = False
+
+    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
+
+    # Log on each process the small summary:
+    logger.info(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
+        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Set seed before initializing model.
+    transformers.set_seed(training_args.seed)
+
+    return model_args, data_args, training_args, finetuning_args
+
+
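+# Example: all arguments can also be supplied as a single JSON file, e.g.
+# `python src/train_sft.py train_config.json`, where train_config.json maps the
+# command-line flags to values, such as {"model_name_or_path": "llama_7b", "do_train": true}.
+# (The file name here is hypothetical; the keys mirror the command-line flags.)
+
+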
+def prepare_data(
+        model_args: ModelArguments,
+        data_args: DataTrainingArguments
+) -> Dataset:
+
+    def checksum(file_path, hash):
+        with open(file_path, "rb") as datafile:
+            binary_data = datafile.read()
+        sha1 = hashlib.sha1(binary_data).hexdigest()
+        if sha1 != hash:
+            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
+
+    max_samples = data_args.max_samples
+    all_datasets: List[Dataset] = [] # support multiple datasets
+
+    for dataset_attr in data_args.dataset_list:
+
+        logger.info("Loading dataset {}...".format(dataset_attr))
+
+        if dataset_attr.load_from == "hf_hub":
+            raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
+        elif dataset_attr.load_from == "script":
+            raw_datasets = load_dataset(
+                os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
+                cache_dir=model_args.cache_dir
+            )
+        elif dataset_attr.load_from == "file":
+            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) # support json, jsonl and csv
+            extension = dataset_attr.file_name.split(".")[-1]
+
+            if dataset_attr.file_sha1 is not None:
+                checksum(data_file, dataset_attr.file_sha1)
+            else:
+                logger.warning("Checksum skipped: missing SHA-1 hash value in dataset_info.json.")
+
+            raw_datasets = load_dataset(
+                extension,
+                data_files=data_file,
+                cache_dir=model_args.cache_dir,
+                use_auth_token=True if model_args.use_auth_token else None
+            )
+        else:
+            raise NotImplementedError
+
+        dataset = raw_datasets[data_args.split]
+
+        if max_samples is not None:
+            max_samples_temp = min(len(dataset), max_samples)
+            dataset = dataset.select(range(max_samples_temp))
+
+        dummy_data = [None] * len(dataset)
+        for column_name, target_name in [
+            ("prompt_column", "prompt"),
+            ("query_column", "query"),
+            ("response_column", "response"),
+            ("history_column", "history")
+        ]: # align every dataset to the same four columns
+            if getattr(dataset_attr, column_name) != target_name:
+                if getattr(dataset_attr, column_name):
+                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
+                else: # None or empty string
+                    dataset = dataset.add_column(target_name, dummy_data)
+        all_datasets.append(dataset)
+
+    if len(data_args.dataset_list) == 1:
+        all_datasets = all_datasets[0]
+    else:
+        all_datasets = concatenate_datasets(all_datasets)
+
+    return all_datasets
+
+
+def preprocess_data(
+        dataset: Dataset,
+        tokenizer: PreTrainedTokenizer,
+        data_args: DataTrainingArguments,
+        training_args: Seq2SeqTrainingArguments,
+        stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+) -> Dataset:
+
+    column_names = list(dataset.column_names)
+    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
+
+    def format_example(examples): # support question with a single answer or multiple answers
+        for i in range(len(examples["prompt"])):
+            if examples["prompt"][i] and examples["response"][i]:
+                query, answer = examples["prompt"][i], examples["response"][i]
+                if examples["query"][i]:
+                    query += examples["query"][i]
+                prompt = "Below is an instruction that describes a task. "
+                prompt += "Write a response that appropriately completes the request.\n"
+                prompt += "Instruction:\n" + prefix
+                if examples["history"][i]:
+                    history = examples["history"][i]
+                    for old_query, response in history:
+                        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
+                prompt += "Human: {}\nAssistant: ".format(query)
+                yield prompt, answer
+
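+    # Example of a rendered prompt for a single-turn example with no prefix:
+    # "Below is an instruction that describes a task. Write a response that
+    # appropriately completes the request.\nInstruction:\nHuman: {query}\nAssistant: "
+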
+    def preprocess_supervised_dataset(examples):
+        # build inputs with format `X <bos> Y <eos>` and labels with format `<ignore> ... <ignore> <bos> Y <eos>`
+        model_inputs = {"input_ids": [], "labels": []}
+        for prompt, answer in format_example(examples):
+            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
+            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
+
+            if len(source_ids) > data_args.max_source_length - 1: # bos token
+                source_ids = source_ids[:data_args.max_source_length - 1]
+            if len(target_ids) > data_args.max_target_length - 1: # eos token
+                target_ids = target_ids[:data_args.max_target_length - 1]
+
+            input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
+            labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
+
+            model_inputs["input_ids"].append(input_ids)
+            model_inputs["labels"].append(labels)
+        return model_inputs
+
+    def preprocess_evaluation_dataset(examples):
+        # build inputs with format `X <bos>` and labels with format `Y <bos>`
+        model_inputs = {"input_ids": [], "labels": []}
+        for prompt, answer in format_example(examples):
+            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
+            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
+
+            if len(source_ids) > data_args.max_source_length - 1: # bos token
+                source_ids = source_ids[:data_args.max_source_length - 1]
+            if len(target_ids) > data_args.max_target_length - 1: # bos token
+                target_ids = target_ids[:data_args.max_target_length - 1]
+
+            input_ids = source_ids + [tokenizer.bos_token_id]
+            labels = target_ids + [tokenizer.bos_token_id]
+
+            model_inputs["input_ids"].append(input_ids)
+            model_inputs["labels"].append(labels)
+        return model_inputs
+
+    def preprocess_pairwise_dataset(examples):
+        # build input pairs with format `X <bos> Y1 <eos>` and `X <bos> Y2 <eos>`
+        model_inputs = {"accept_ids": [], "reject_ids": []}
+        for prompt, answer in format_example(examples):
+            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
+            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
+            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
+
+            if len(source_ids) > data_args.max_source_length - 1: # bos token
+                source_ids = source_ids[:data_args.max_source_length - 1]
+            if len(accept_ids) > data_args.max_target_length - 1: # eos token
+                accept_ids = accept_ids[:data_args.max_target_length - 1]
+            if len(reject_ids) > data_args.max_target_length - 1: # eos token
+                reject_ids = reject_ids[:data_args.max_target_length - 1]
+
+            accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
+            reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
+
+            model_inputs["accept_ids"].append(accept_ids)
+            model_inputs["reject_ids"].append(reject_ids)
+        return model_inputs
+
+    def print_sft_dataset_example(example):
+        print("input_ids:\n{}".format(example["input_ids"]))
+        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
+        print("label_ids:\n{}".format(example["labels"]))
+        print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
+
+    def print_pairwise_dataset_example(example):
+        print("accept_ids:\n{}".format(example["accept_ids"]))
+        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
+        print("reject_ids:\n{}".format(example["reject_ids"]))
+        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
+
+    def print_ppo_dataset_example(example):
+        print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) + + if stage == "sft": + if (not training_args.do_train) and training_args.predict_with_generate: # with generation + preprocess_function = preprocess_evaluation_dataset + else: # without generation + preprocess_function = preprocess_supervised_dataset + elif stage == "rm": + preprocess_function = preprocess_pairwise_dataset + elif stage == "ppo": + preprocess_function = preprocess_evaluation_dataset + + with training_args.main_process_first(desc="dataset map pre-processing"): + dataset = dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset" + ) + + if stage == "sft": + print_sft_dataset_example(dataset[0]) + elif stage == "rm": + print_pairwise_dataset_example(dataset[0]) + elif stage == "ppo": + print_ppo_dataset_example(dataset[0]) + + return dataset diff --git a/src/utils/config.py b/src/utils/config.py new file mode 100644 index 00000000..fe35a6ea --- /dev/null +++ b/src/utils/config.py @@ -0,0 +1,212 @@ +import os +import json +from typing import List, Literal, Optional +from dataclasses import asdict, dataclass, field + + +@dataclass +class DatasetAttr: + + load_from: str + dataset_name: Optional[str] = None + file_name: Optional[str] = None + file_sha1: Optional[str] = None + + def __post_init__(self): + self.prompt_column = "instruction" + self.query_column = "input" + self.response_column = "output" + self.history_column = None + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune. + """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} + ) + use_fast_tokenizer: Optional[bool] = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} + ) + use_auth_token: Optional[bool] = field( + default=False, + metadata={"help": "Will use the token generated when running `huggingface-cli login`."} + ) + quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the model."} + ) + checkpoint_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."} + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory containing the checkpoints of the reward model."} + ) + + def __post_init__(self): + if self.checkpoint_dir is not None: # support merging lora weights + self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluation. + """ + dataset: Optional[str] = field( + default="alpaca_zh", + metadata={"help": "The name of provided dataset(s) to use. 
Use comma to separate multiple datasets."}
+    )
+    dataset_dir: Optional[str] = field(
+        default="data",
+        metadata={"help": "The name of the folder containing datasets."}
+    )
+    split: Optional[str] = field(
+        default="train",
+        metadata={"help": "Which dataset split to use for training and evaluation."}
+    )
+    overwrite_cache: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets."}
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."}
+    )
+    max_source_length: Optional[int] = field(
+        default=512,
+        metadata={"help": "The maximum total input sequence length after tokenization."}
+    )
+    max_target_length: Optional[int] = field(
+        default=512,
+        metadata={"help": "The maximum total output sequence length after tokenization."}
+    )
+    max_samples: Optional[int] = field(
+        default=None,
+        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
+    )
+    num_beams: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`."}
+    )
+    ignore_pad_token_for_loss: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
+    )
+    source_prefix: Optional[str] = field(
+        default=None,
+        metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+    )
+    dev_ratio: Optional[float] = field(
+        default=0,
+        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
+    )
+
+    def __post_init__(self): # support mixing multiple datasets
+        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
+        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
+            dataset_info = json.load(f)
+
+        self.dataset_list: List[DatasetAttr] = []
+        for name in dataset_names:
+            if name not in dataset_info:
+                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
+
+            if "hf_hub_url" in dataset_info[name]:
+                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+            elif "script_url" in dataset_info[name]:
+                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+            else:
+                dataset_attr = DatasetAttr(
+                    "file",
+                    file_name=dataset_info[name]["file_name"],
+                    file_sha1=dataset_info[name].get("file_sha1", None)
+                )
+
+            if "columns" in dataset_info[name]:
+                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+
+            self.dataset_list.append(dataset_attr)
+
+
+@dataclass
+class FinetuningArguments:
+    """
+    Arguments pertaining to which fine-tuning techniques we are going to use.
+    """
+    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
+        default="lora",
+        metadata={"help": "Which fine-tuning method to use."}
+    )
+    num_layer_trainable: Optional[int] = field(
+        default=3,
+        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
+    )
+    name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
+        default="mlp",
+        metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
+    )
+    lora_rank: Optional[int] = field(
+        default=8,
+        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
+    )
+    lora_alpha: Optional[float] = field(
+        default=32.0,
+        metadata={"help": "The scale factor for LoRA fine-tuning (similar to the learning rate)."}
+    )
+    lora_dropout: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
+    )
+    lora_target: Optional[str] = field(
+        default="q_proj,v_proj",
+        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
+    )
+    resume_lora_training: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+    )
+    plot_loss: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+    )
+
+    def __post_init__(self):
+        if isinstance(self.lora_target, str):
+            self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA
+
+        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
+            trainable_layer_ids = [31 - k for k in range(self.num_layer_trainable)] # NOTE: assumes LLaMA-7B, which has 32 decoder layers
+        else: # fine-tuning the first n layers if num_layer_trainable < 0
+            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
+
+        if self.name_module_trainable == "mlp":
+            self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
+        elif self.name_module_trainable == "qkv":
+            self.trainable_layers = ["layers.{:d}.self_attn".format(idx) for idx in trainable_layer_ids] # LLaMA packs Q/K/V into self_attn.{q,k,v}_proj
+
+        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
+
+    def save_to_json(self, json_path: str):
+        """Save the content of this instance in JSON format inside `json_path`."""
+        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
+        with open(json_path, "w", encoding="utf-8") as f:
+            f.write(json_string)
+
+    @classmethod
+    def load_from_json(cls, json_path: str):
+        """Create an instance from the content of `json_path`."""
+        with open(json_path, "r", encoding="utf-8") as f:
+            text = f.read()
+        return cls(**json.loads(text))
diff --git a/src/utils/data_collator.py b/src/utils/data_collator.py
new file mode 100644
index 00000000..2932f2ba
--- /dev/null
+++ b/src/utils/data_collator.py
@@ -0,0 +1,67 @@
+import torch
+
+from typing import Dict, Optional, Sequence, Union
+
+from transformers import DataCollatorWithPadding
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+from .other import IGNORE_INDEX
+
+
+class DataCollatorForLLaMA(DataCollatorWithPadding):
+    r"""
+    Data collator for LLaMA. It dynamically pads batched data to the longest sequence in the batch.
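For reference, `DataTrainingArguments.__post_init__` above resolves each dataset name against `data/dataset_info.json`. A hypothetical sketch of the three supported source kinds (Hub dataset, loading script, local file); the entry names and hash below are placeholders, not contents of the real file:

```python
# Hypothetical dataset_info.json entries, shown as a Python dict for illustration.
dataset_info = {
    "alpaca_en": {"hf_hub_url": "tatsu-lab/alpaca"},      # fetched from the Hugging Face Hub
    "my_script_data": {"script_url": "example_dataset"},  # loaded via a local dataset script
    "my_local_data": {                                    # loaded from a local json/jsonl/csv file
        "file_name": "my_data.json",
        "file_sha1": "0000000000000000000000000000000000000000",  # placeholder checksum
        "columns": {"prompt": "instruction", "query": "input", "response": "output", "history": "history"},
    },
}
```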
+ """ + def __init__( + self, + tokenizer: PreTrainedTokenizer, + model: PreTrainedModel, + ignore_pad_token_for_loss: Optional[bool] = False + ): + super().__init__(tokenizer, padding=True) + self.model = model + self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id + + def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: + r""" + Generates attention masks for left-padded sequences. + """ + batch_size, seq_length = input_ids.size() + attention_mask = torch.ones((batch_size, seq_length), device=device) + for i, seq in enumerate(input_ids): + attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding + attention_mask = attention_mask.bool() + return attention_mask + + def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We adopt left-padding in both training and evaluation. + """ + if isinstance(features[0]["input_ids"], torch.Tensor): + input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features] + else: + input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features] + + if "labels" in features[0]: + if isinstance(features[0]["labels"], torch.Tensor): + labels = [feature["labels"].clone().detach().flip(0) for feature in features] + else: + labels = [torch.tensor(feature["labels"]).flip(0) for feature in features] + input_ids = input_ids + labels # pad them to the same length + + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1) + + batch = {} + + if "labels" in features[0]: + input_ids, labels = input_ids.split(len(features), dim=0) + labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id) + batch["labels"] = labels + + batch["input_ids"] = input_ids + batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device) + + return batch diff --git a/src/utils/other.py b/src/utils/other.py new file mode 100644 index 00000000..8a30dd35 --- /dev/null +++ b/src/utils/other.py @@ -0,0 +1,205 @@ +import os +import sys +import json +import torch +import logging +from typing import Dict, List, Optional + +from transformers import Seq2SeqTrainingArguments +from transformers.trainer import TRAINER_STATE_NAME +from transformers.modeling_utils import PreTrainedModel +from transformers.generation.utils import LogitsProcessorList +from transformers.generation.logits_process import LogitsProcessor + +from peft.utils.other import WEIGHTS_NAME + + +IGNORE_INDEX = -100 +VALUE_HEAD_FILE_NAME = "value_head.bin" +FINETUNING_ARGS_NAME = "finetuning_args.json" + + +logger = logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + handlers=[logging.StreamHandler(sys.stdout)] +) + + +def get_logger(name: str) -> logging.Logger: + return logging.getLogger(name) + + +class AverageMeter: + r""" + Computes and stores the average and current value. + """ + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +# Avoid runtime error in model.generate(do_sample=True). 
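The flip-pad-flip trick in `__call__` above is the entire left-padding mechanism, so it is worth seeing in isolation. A minimal sketch with toy ids (pad id 0 is an assumption); note that `get_attention_masks` instead scans for the first non-pad position, which agrees with the simple comparison below as long as the pad id never occurs inside a real sequence:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_id = 0
seqs = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]

# Reverse each sequence, right-pad, then reverse back: the padding ends up on the left.
padded = pad_sequence([s.flip(0) for s in seqs], batch_first=True, padding_value=pad_id).flip(-1)
print(padded)            # tensor([[11, 12, 13], [ 0, 21, 22]])
print(padded != pad_id)  # tensor([[ True,  True,  True], [False,  True,  True]])
```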
+# Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54 +class InvalidScoreLogitsProcessor(LogitsProcessor): + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +def get_logits_processor() -> LogitsProcessorList: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + return logits_processor + + +# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32 +# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35 +def prepare_model_for_training( + model: PreTrainedModel, + output_embedding_layer_name: Optional[str] = "lm_head", + use_gradient_checkpointing: Optional[bool] = True, + layer_norm_names: Optional[List[str]] = ["norm"] # for LLaMA setting +) -> PreTrainedModel: + + for name, param in model.named_parameters(): + if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): + param.data = param.data.to(torch.float32) + + if use_gradient_checkpointing: + model.enable_input_require_grads() + model.gradient_checkpointing_enable() + model.config.use_cache = False # turn off when gradient checkpointing is enabled + + if hasattr(model, output_embedding_layer_name): + output_embedding_layer = getattr(model, output_embedding_layer_name) + input_dtype = output_embedding_layer.weight.dtype + + class CastOutputToFloat(torch.nn.Sequential): + + def forward(self, x): + return super().forward(x.to(input_dtype)).to(torch.float32) + + setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) + + return model + + +def print_trainable_params(model: torch.nn.Module) -> None: + trainable_params, all_param = 0, 0 + for param in model.parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if param.requires_grad: + trainable_params += num_params + print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param)) + + +def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters + state_dict = model.state_dict() + filtered_state_dict = {} + + for k, v in model.named_parameters(): + if v.requires_grad: + filtered_state_dict[k] = state_dict[k].cpu().clone().detach() + + return filtered_state_dict + + +def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None: + weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) + assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights." + model_state_dict = torch.load(weights_file, map_location="cpu") + model.load_state_dict(model_state_dict, strict=False) # skip missing keys + + +def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None: + valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME) + assert os.path.exists(valuehead_file), f"Provided path ({checkpoint_dir}) does not contain the valuehead weights." 
+ valuehead_state_dict = torch.load(valuehead_file, map_location="cpu") + model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"]) + model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"]) + model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"])) + model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])) + + +def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: + r""" + Configures device map for LLaMA. + + Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8 + """ + num_layers = 28 + layers_per_gpu = 30 / num_gpus + device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0} + added_layers = 2 + target_gpu = 0 + + for i in range(num_layers): + if added_layers >= layers_per_gpu: + target_gpu += 1 + added_layers = 0 + assert target_gpu < num_gpus + device_map[f"model.layers.{i}"] = target_gpu + added_layers += 1 + + return device_map + + +def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]: + """ + EMA implementation according to TensorBoard. + """ + last = scalars[0] + smoothed = list() + for next_val in scalars: + smoothed_val = last * weight + (1 - weight) * next_val + smoothed.append(smoothed_val) + last = smoothed_val + return smoothed + + +def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None: + import matplotlib.pyplot as plt + data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r")) + + for key in keys: + steps, metrics = [], [] + for i in range(len(data["log_history"])): + if key in data["log_history"][i]: + steps.append(data["log_history"][i]["step"]) + metrics.append(data["log_history"][i][key]) + + if len(metrics) == 0: + logger.warning(f"No metric {key} to plot.") + continue + + plt.figure() + plt.plot(steps, metrics, alpha=0.4, label="original") + plt.plot(steps, smooth(metrics), label="smoothed") + plt.title("training {} of {}".format(key, training_args.output_dir)) + plt.xlabel("step") + plt.ylabel(key) + plt.legend() + plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100) + print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key))) diff --git a/src/utils/pairwise.py b/src/utils/pairwise.py new file mode 100644 index 00000000..3c2aa21d --- /dev/null +++ b/src/utils/pairwise.py @@ -0,0 +1,51 @@ +import torch +from typing import Dict, Sequence, Union + +from .data_collator import DataCollatorForLLaMA + +from .peft_trainer import PeftTrainer + +from .other import get_logger + +logger = get_logger(__name__) + + +class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA): + r""" + Data collator for pairwise data. + """ + + def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. + """ + features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features] + return super().__call__(features) + + +class PairwiseTrainerForLLaMA(PeftTrainer): + r""" + Inherits PeftTrainer to compute pairwise loss. 
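The 2n-example layout that `PairwiseDataCollatorForLLaMA.__call__` produces can be checked without any model; a sketch with toy ids:

```python
# The comprehension iterates over keys first, so all chosen sequences precede all rejected ones.
features = [
    {"accept_ids": [1, 11, 2], "reject_ids": [1, 21, 22, 2]},
    {"accept_ids": [1, 12, 2], "reject_ids": [1, 23, 2]},
]
flat = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
print([f["input_ids"] for f in flat])
# [[1, 11, 2], [1, 12, 2], [1, 21, 22, 2], [1, 23, 2]] -> first n chosen, last n rejected
```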
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.can_return_loss = True # override property to return eval_loss + + def compute_loss(self, model, inputs, return_outputs=False): + r""" + Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. + + We use score on the EOS token to represent reward of the whole sentence. + + Subclass and override to inject custom behavior. It should not be directly used by external scripts. + """ + batch_size = inputs["input_ids"].size(0) // 2 + _, _, values = model(**inputs) + r_accept, r_reject = values[:, -1].split(batch_size, dim=0) + loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() + outputs = {"r_accept": r_accept, "r_reject": r_reject} + return (loss, outputs) if return_outputs else loss diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py new file mode 100644 index 00000000..57d54a8d --- /dev/null +++ b/src/utils/peft_trainer.py @@ -0,0 +1,78 @@ +import os +import torch +from typing import Dict, Optional + +from transformers import Seq2SeqTrainer +from transformers.trainer import TRAINING_ARGS_NAME +from transformers.modeling_utils import unwrap_model + +from peft.utils.other import WEIGHTS_NAME + +from .config import FinetuningArguments + +from .other import ( + get_logger, + get_state_dict, + load_trainable_params, + load_valuehead_params, + FINETUNING_ARGS_NAME, + VALUE_HEAD_FILE_NAME +) + + +logger = get_logger(__name__) + + +class PeftTrainer(Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. + """ + + def __init__(self, finetuning_args: FinetuningArguments, **kwargs): + super().__init__(**kwargs) + self.finetuning_args = finetuning_args + + def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: + r""" + Saves trainable parameters as model checkpoint. + + This function will only be executed at the process zero. + + Subclass and override to inject custom behavior. It should not be directly used by external scripts. + """ + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + model = unwrap_model(self.model) + + if hasattr(model, "pretrained_model"): # for models with valuehead + backbone_model = getattr(model, "pretrained_model") + else: + backbone_model = model + + if hasattr(backbone_model, "peft_config"): # peft methods + backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights + else: + torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights + + if hasattr(model, "v_head"): # save valuehead weights + torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME)) + + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME)) + + def _load_best_model(self): + r""" + Loads trainable parameters from model checkpoint. + + Subclass and override to inject custom behavior. It should not be directly used by external scripts. 
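A quick numeric check of `compute_loss` above, using hypothetical last-token scores for 2n = 4 sequences; the mis-ranked second pair dominates the loss:

```python
import torch

rewards = torch.tensor([2.0, 1.0, 0.5, 1.5])  # first n = 2 chosen, last n = 2 rejected
r_accept, r_reject = rewards.split(2, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss)  # ~0.5877: pair (2.0, 0.5) is ranked correctly, pair (1.0, 1.5) is not
```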
+        """
+        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+        model = unwrap_model(self.model)
+        if hasattr(model, "peft_config"): # peft methods
+            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
+        else:
+            load_trainable_params(model, self.state.best_model_checkpoint)
+
+        if hasattr(model, "v_head"):
+            load_valuehead_params(model, self.state.best_model_checkpoint)
diff --git a/src/utils/ppo.py b/src/utils/ppo.py
new file mode 100644
index 00000000..85c69505
--- /dev/null
+++ b/src/utils/ppo.py
@@ -0,0 +1,241 @@
+import os
+import math
+import torch
+from tqdm import tqdm
+from typing import Callable, Dict, List, Literal, Optional, Tuple
+
+from transformers import Seq2SeqTrainingArguments
+from transformers.trainer import TrainerState
+from transformers.modeling_utils import PreTrainedModel
+
+from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
+from trl.core import LengthSampler
+from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
+
+from .peft_trainer import PeftTrainer
+
+from .config import FinetuningArguments
+
+from .other import (
+    AverageMeter,
+    get_logger,
+    get_logits_processor
+)
+
+
+logger = get_logger(__name__)
+
+
+def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
+    if target == "reward": # save the default head temporarily
+        valuehead_state_dict = model.v_head.state_dict()
+
+        # store under the "default" name so the restore path below can find it
+        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
+        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
+
+    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
+    model.v_head.load_state_dict({
+        "summary.weight": getattr(model, "{}_head_weight".format(target)),
+        "summary.bias": getattr(model, "{}_head_bias".format(target))
+    })
+
+
+class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
+    r"""
+    Inherits PPOTrainer.
+    """
+
+    def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
+        PPOTrainer.__init__(self, **kwargs)
+        self.args = training_args
+        self.finetuning_args = finetuning_args
+        self.state = TrainerState()
+        self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
+
+    def ppo_train(self, max_target_length: int) -> None:
+        r"""
+        Implements the training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
+        """
+        total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
+        len_dataloader = len(self.dataloader)
+        num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
+        num_examples = len(self.dataset)
+        num_train_epochs = self.args.num_train_epochs
+        max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
+
+        if self.is_world_process_zero():
+            logger.info("***** Running training *****")
+            logger.info(f"  Num examples = {num_examples}")
+            logger.info(f"  Num Epochs = {num_train_epochs}")
+            logger.info(f"  Instantaneous batch size per device = {self.config.batch_size}")
+            logger.info(f"  Total train batch size (w.
parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") + logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") + + # Keyword arguments for `model.generate` + gen_kwargs = { + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.tokenizer.pad_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "logits_processor": get_logits_processor() + } + output_length_sampler = LengthSampler(max_target_length // 2, max_target_length) + unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model) + + dataiter = iter(self.dataloader) + steps_trained = 0 + loss_meter = AverageMeter() + reward_meter = AverageMeter() + + for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()): + + for _ in range(self.config.gradient_accumulation_steps): + + batch = next(dataiter) + steps_trained += 1 + + unwrapped_model.gradient_checkpointing_disable() + unwrapped_model.config.use_cache = True + + # Get response from LLaMA + query_tensors: torch.Tensor = batch["input_ids"] + response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs) + + queries: List[torch.Tensor] = [] + responses: List[torch.Tensor] = [] + for i in range(len(query_tensors)): + query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0] + response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1 + queries.append(query_tensors[i, query_length:]) # remove padding from left + if response_length < 2: # make response have at least 2 tokens + responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id)) + else: + responses.append(response_tensors[i, :response_length]) # remove padding from right + + # Compute rewards + replace_model(unwrapped_model, target="reward") + _, _, values = self.model(**self.prepare_model_inputs(queries, responses)) + rewards = [reward for reward in values[:, -1]] + replace_model(unwrapped_model, target="default") # make sure the model is default at the end + + # Run PPO step + unwrapped_model.gradient_checkpointing_enable() + unwrapped_model.config.use_cache = False + + stats = self.step(queries, responses, rewards) + + loss_meter.update(stats["ppo/loss/total"]) + reward_meter.update(torch.tensor(rewards).sum().item(), n=len(rewards)) + + if steps_trained == len_dataloader: + dataiter = iter(self.dataloader) + steps_trained = 0 + + if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0: + logs = { + "loss": round(loss_meter.avg, 4), + "reward": round(reward_meter.avg, 4), + "learning_rate": stats["ppo/learning_rate"], + "epoch": round(step / num_steps_per_epoch, 2) + } + print(logs) + logs["step"] = step + self.state.log_history.append(logs) + loss_meter.reset() + reward_meter.reset() + + if (step+1) % self.args.save_steps == 0: # save checkpoint + self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}")) + + @torch.no_grad() + def generate( + self, + inputs: Dict[str, torch.Tensor], + length_sampler: Callable = None, + return_prompt: bool = True, + **generation_kwargs, + ) -> torch.Tensor: + r""" + Generates model's responses given queries. + + Subclass and override to inject custom behavior. 
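The padding bookkeeping in the loop above (queries arrive left-padded, generated responses are right-padded) reduces to two nonzero lookups; a standalone sketch with toy ids and pad id 0 assumed:

```python
import torch

pad_id = 0
query = torch.tensor([0, 0, 35, 48])     # left-padded prompt
response = torch.tensor([51, 62, 2, 0])  # right-padded generation
query_start = (query != pad_id).nonzero()[0]
response_length = (response != pad_id).nonzero()[-1] + 1
print(query[query_start:])         # tensor([35, 48])
print(response[:response_length])  # tensor([51, 62,  2])
```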
+ """ + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + + unwrapped_model = self.accelerator.unwrap_model(self.model) + + response = unwrapped_model.generate(**inputs, **generation_kwargs) + + # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop + # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273 + if unwrapped_model.pretrained_model.generation_config._from_model_config: + unwrapped_model.pretrained_model.generation_config._from_model_config = False + + if not return_prompt and not self.is_encoder_decoder: + return response[:, inputs["input_ids"].size(1):] + return response + + def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]: + input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] + input_data = self.data_collator([{"input_ids": ids} for ids in input_ids]) + input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None} + input_data.pop("labels", None) # we don't want to compute LM losses + return input_data + + @PPODecorators.empty_cuda_cache() + def batched_forward_pass( + self, + model: AutoModelForCausalLMWithValueHead, + queries: torch.Tensor, + responses: torch.Tensor, + model_inputs: dict, + ): + r""" + Calculates model outputs in multiple batches. + + Subclass and override to inject custom behavior. + """ + bs = len(model_inputs["input_ids"]) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + for i in range(int(bs / fbs)): + input_kwargs = {k: v[i * fbs : (i + 1) * fbs] for k, v in model_inputs.items()} + input_ids: torch.Tensor = input_kwargs["input_ids"] # left-padded sequences + logits, _, values = model(**input_kwargs) + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + + masks = torch.zeros_like(input_ids) + for j in range(fbs): + start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item() + masks[j][start:] = 1 + if len(masks[j][start:]) < 2: + raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.") + + all_logits.append(logits) + all_values.append(values) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1], + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + + def save_model(self, output_dir: Optional[str] = None) -> None: + r""" + Saves model checkpoint. + + Subclass and override to inject custom behavior. + """ + if self.args.should_save: + self._save(output_dir) diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py new file mode 100644 index 00000000..4a48393b --- /dev/null +++ b/src/utils/seq2seq.py @@ -0,0 +1,96 @@ +import os +import json +import numpy as np +from dataclasses import dataclass +from typing import Dict, List, Sequence, Tuple, Union + +from transformers.trainer import PredictionOutput +from transformers.tokenization_utils import PreTrainedTokenizer + +import jieba +from rouge_chinese import Rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction + +from .peft_trainer import PeftTrainer + +from .other import get_logger, IGNORE_INDEX + + +logger = get_logger(__name__) + + +@dataclass +class ComputeMetrics: + r""" + Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA. 
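The mask construction in `batched_forward_pass` keys off the `<bos>` token that the preprocessing inserted between prompt and response; a toy sketch (bos id 1 is an assumption):

```python
import torch

bos_id = 1
input_ids = torch.tensor([[0, 0, 1, 35, 48]])  # left-padded; <bos> marks where X ends
start = (input_ids[0] == bos_id).nonzero()[0].item()
masks = torch.zeros_like(input_ids)
masks[0, start:] = 1
print(masks)  # tensor([[0, 0, 1, 1, 1]])
```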
+ + Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307 + """ + + tokenizer: PreTrainedTokenizer + + def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: + r""" + Uses the model predictions to compute metrics. + """ + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True. + preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) + labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) + + score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + for pred, label in zip(preds, labels): + pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query + hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True))) + reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True))) + + if len(" ".join(hypothesis).split()) == 0: + result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} + else: + rouge = Rouge() + scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + return {k: float(np.mean(v)) for k, v in score_dict.items()} + + +class Seq2SeqTrainerForLLaMA(PeftTrainer): + r""" + Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. + """ + + def save_predictions( + self, + predict_results: PredictionOutput, + tokenizer: PreTrainedTokenizer + ) -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. + """ + if not self.is_world_process_zero(): + return + + preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) + labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) + + preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries + preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds] + labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels] + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for pred, label in zip(preds, labels): + res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + writer.write("\n".join(res)) diff --git a/src/web_demo.py b/src/web_demo.py new file mode 100644 index 00000000..8801ee18 --- /dev/null +++ b/src/web_demo.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Implements user interface in browser for LLaMA fine-tuned with PEFT. 
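To see what `ComputeMetrics` above produces, here is a stripped-down sketch on toy strings (the real `__call__` first decodes token ids, and computes BLEU over raw sequences rather than jieba tokens):

```python
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

hypothesis = list(jieba.cut("今天天气很好"))
reference = list(jieba.cut("今天天气不错"))

scores = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))
print(scores[0]["rouge-l"]["f"])  # F1 of the longest common subsequence

bleu = sentence_bleu([reference], hypothesis, smoothing_function=SmoothingFunction().method3)
print(round(bleu * 100, 4))
```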
+# Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
+
+
+import torch
+import mdtex2html
+import gradio as gr
+
+from utils import ModelArguments, auto_configure_device_map, load_pretrained
+from transformers import HfArgumentParser
+
+
+parser = HfArgumentParser(ModelArguments)
+model_args, = parser.parse_args_into_dataclasses()
+model, tokenizer = load_pretrained(model_args)
+if torch.cuda.device_count() > 1:
+    from accelerate import dispatch_model
+    device_map = auto_configure_device_map(torch.cuda.device_count())
+    model = dispatch_model(model, device_map)
+else:
+    model = model.cuda()
+model.eval()
+
+
+"""Override Chatbot.postprocess"""
+
+def postprocess(self, y):
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            None if message is None else mdtex2html.convert(message),
+            None if response is None else mdtex2html.convert(response),
+        )
+    return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def parse_text(text): # copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = f'<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>" + line
+    text = "".join(lines)
+    return text
+
+
+def predict(input, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(input), ""))
+
+    inputs = tokenizer([input], return_tensors="pt")
+    inputs = inputs.to(model.device)
+    gen_kwargs = {
+        "do_sample": True,
+        "top_p": top_p,
+        "temperature": temperature,
+        "num_beams": 1,
+        "max_length": max_length,
+        "repetition_penalty": 1.0
+    }
+    with torch.no_grad():
+        generation_output = model.generate(**inputs, **gen_kwargs)
+    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+    response = tokenizer.decode(outputs, skip_special_tokens=True)
+    history = history + [(input, response)]
+    chatbot[-1] = (parse_text(input), parse_text(response))
+    yield chatbot, history
+
+
+def reset_user_input():
+    return gr.update(value='')
+
+
+def reset_state():
+    return [], []
+
+
+with gr.Blocks() as demo:
+    gr.HTML("""
+    <h1 align="center">
+        LLaMA-Efficient-Tuning
+    </h1>
+    """)
+
+    chatbot = gr.Chatbot()
+    with gr.Row():
+        with gr.Column(scale=4):
+            with gr.Column(scale=12):
+                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
+            with gr.Column(min_width=32, scale=1):
+                submitBtn = gr.Button("Submit", variant="primary")
+        with gr.Column(scale=1):
+            emptyBtn = gr.Button("Clear History")
+            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+
+    history = gr.State([])
+
+    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
+                    show_progress=True)
+    submitBtn.click(reset_user_input, [], [user_input])
+
+    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True)