Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-07-30 18:22:51 +08:00)

Initial commit

Former-commit-id: 769c6ab56be0c9d26e9289f61ac54a4068d935c1
This commit is contained in: 54b8ce7b63

2  .gitattributes  vendored  Normal file
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

201  LICENSE  Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

29  README.md  Normal file
@@ -0,0 +1,29 @@
# LLaMA Efficient Tuning

1. Download the weights of the LLaMA models.
2. Convert them to HF format using this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py):

```bash
python convert_llama_weights_to_hf.py \
    --input_dir path_to_llama_weights --model_size 7B --output_dir llama_7b
```

3. Fine-tune the LLaMA models:

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --model_name_or_path llama_7b \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16
```
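
Once fine-tuning finishes, the bundled `src/cli_demo.py` can chat with the tuned model. A minimal sketch, assuming the checkpoint path from the command above (the flag names follow the `ModelArguments` dataclass used by the script):

```bash
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
    --model_name_or_path llama_7b \
    --checkpoint_dir path_to_sft_checkpoint
```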

53  data/README.md  Normal file
@@ -0,0 +1,53 @@
Data format in `dataset_info.json`:
```json
"dataset_name": {
    "hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, the 3 arguments below are ignored)",
    "script_url": "the name of the directory containing a dataset loading script. (if specified, the 2 arguments below are ignored)",
    "file_name": "the name of the dataset file in this directory. (required if the above are not specified)",
    "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
    "columns": {
        "prompt": "the name of the column in the dataset containing the prompts. (default: instruction)",
        "query": "the name of the column in the dataset containing the queries. (default: input)",
        "response": "the name of the column in the dataset containing the responses. (default: output)",
        "history": "the name of the column in the dataset containing the chat history. (default: None)"
    }
}
```
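
For instance, registering a local custom file next to a Hub-hosted dataset might look like this (a sketch: `my_dataset`, `my_data.json` and its hash are hypothetical placeholders, while `alpaca_en` mirrors the real entry in `dataset_info.json` below):

```json
{
    "alpaca_en": {
        "hf_hub_url": "tatsu-lab/alpaca"
    },
    "my_dataset": {
        "file_name": "my_data.json",
        "file_sha1": "0000000000000000000000000000000000000000",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "history": "history"
        }
    }
}
```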

A brief introduction to some of the built-in datasets:

| Dataset | Size | Description |
| --- | --- | --- |
| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | The Alpaca dataset released by Stanford, used to train early LLaMA-based models such as Alpaca |
| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | The Alpaca dataset translated into Chinese with ChatGPT |
| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | Self-instruction datasets generated with GPT-4 |
| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | About 2 million Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | About 1 million Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | About 500 thousand Chinese instruction examples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | About 400 thousand personalized role-play dialogues generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, with character descriptions |
| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | About 250 thousand Chinese math problems generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, with solution steps |
| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | About 800 thousand multi-turn dialogues between users and an assistant, generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | Multilingual data covering Japanese, Simplified and Traditional Chinese, English and more, originally used to train the Guanaco model |
| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | The Chinese dataset of the Chinese conversational model firefly, covering multiple NLP tasks |
| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | English code generation dataset |
| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | A collection of instruction datasets for fine-tuning |
| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | Chinese question answering dataset collected from Baidu Zhidao |
| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | Large-scale multi-turn dialogue dataset released by Tsinghua NLP (THUNLP) |

Note: the BELLE datasets were generated by ChatGPT and their accuracy is not guaranteed; the same caveat applies to all self-instruction datasets produced by GPT-like models.

1  data/alpaca_data_en_52k.json.REMOVED.git-id  Normal file
@@ -0,0 +1 @@
3779ddbc040543ab1834ef216c983d6fcc06cc9a

1  data/alpaca_data_zh_51k.json.REMOVED.git-id  Normal file
@@ -0,0 +1 @@
fc9a6a3458caca2af8dafc6181773fe10c6d8657

1  data/alpaca_gpt4_data_en.json.REMOVED.git-id  Normal file
@@ -0,0 +1 @@
25508714b7879a1e5a6764ba7f979a980f549f1a

1  data/alpaca_gpt4_data_zh.json.REMOVED.git-id  Normal file
@@ -0,0 +1 @@
7cb6a7d11455bddc3d495750a2392683d775b184

1  data/comparison_gpt4_data_en.json.REMOVED.git-id  Normal file
@@ -0,0 +1 @@
f437d58b7791609ee91f064551c5c5734a0fd97a

1  data/comparison_gpt4_data_zh.json.REMOVED.git-id  Normal file
@@ -0,0 +1 @@
0e346cf70e633456c7e83f68765361016005447a

97  data/dataset_info.json  Normal file
@@ -0,0 +1,97 @@
{
    "alpaca_en": {
        "hf_hub_url": "tatsu-lab/alpaca"
    },
    "alpaca_zh": {
        "file_name": "alpaca_data_zh_51k.json",
        "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
    },
    "alpaca_gpt4_en": {
        "file_name": "alpaca_gpt4_data_en.json",
        "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
    },
    "alpaca_gpt4_zh": {
        "file_name": "alpaca_gpt4_data_zh.json",
        "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
    },
    "belle_0.5m": {
        "hf_hub_url": "BelleGroup/train_0.5M_CN"
    },
    "belle_1m": {
        "hf_hub_url": "BelleGroup/train_1M_CN"
    },
    "belle_2m": {
        "hf_hub_url": "BelleGroup/train_2M_CN"
    },
    "belle_dialog": {
        "hf_hub_url": "BelleGroup/generated_chat_0.4M"
    },
    "belle_math": {
        "hf_hub_url": "BelleGroup/school_math_0.25M"
    },
    "belle_multiturn": {
        "hf_hub_url": "BelleGroup/multiturn_chat_0.8M"
    },
    "guanaco": {
        "hf_hub_url": "JosephusCheung/GuanacoDataset"
    },
    "firefly": {
        "hf_hub_url": "YeungNLP/firefly-train-1.1M",
        "columns": {
            "prompt": "input",
            "query": "",
            "response": "target",
            "history": ""
        }
    },
    "codealpaca": {
        "hf_hub_url": "sahil2801/CodeAlpaca-20k"
    },
    "alpaca_cot": {
        "hf_hub_url": "QingyiSi/Alpaca-CoT"
    },
    "webqa": {
        "hf_hub_url": "suolyer/webqa",
        "columns": {
            "prompt": "input",
            "query": "",
            "response": "output",
            "history": ""
        }
    },
    "ultra_chat": {
        "script_url": "ultra_chat",
        "columns": {
            "prompt": "instruction",
            "query": "",
            "response": "output",
            "history": "history"
        }
    },
    "example": {
        "script_url": "example_dataset",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "history": "history"
        }
    },
    "comparison_gpt4_en": {
        "file_name": "comparison_gpt4_data_en.json",
        "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
    },
    "comparison_gpt4_zh": {
        "file_name": "comparison_gpt4_data_zh.json",
        "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
    },
    "hh_rlhf_en": {
        "script_url": "hh_rlhf_en",
        "columns": {
            "prompt": "instruction",
            "query": "",
            "response": "output",
            "history": "history"
        }
    }
}

46  data/example_dataset/example_dataset.py  Normal file
@@ -0,0 +1,46 @@
import json
import datasets
from typing import Any, Dict, List


_DESCRIPTION = "An example of dataset for LLaMA."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"


class ExampleDataset(datasets.GeneratorBasedBuilder):

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features({
            "instruction": datasets.Value("string"),
            "input": datasets.Value("string"),
            "output": datasets.Value("string"),
            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_path = dl_manager.download(_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": file_path
                }
            )
        ]

    def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
        # Load the whole JSON file and yield one example per record.
        with open(filepath, "r", encoding="utf-8") as f:
            example_dataset = json.load(f)
        for key, example in enumerate(example_dataset):
            yield key, example
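
A quick way to exercise the loading script above (a sketch, assuming it is run from the repository root so that the relative `_URL` resolves):

```python
from datasets import load_dataset

# load_dataset accepts a directory containing a loading script;
# the builder reads examples.json and yields one record per entry.
dataset = load_dataset("data/example_dataset")["train"]
print(dataset[0]["instruction"])
print(dataset[0]["history"])
```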

20  data/example_dataset/examples.json  Normal file
@@ -0,0 +1,20 @@
[
    {
        "instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
        "input": "",
        "output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
        "history": [
            ["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
            ["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
        ]
    },
    {
        "instruction": "好的,谢谢你!",
        "input": "",
        "output": "不客气,有其他需要帮忙的地方可以继续问我。",
        "history": [
            ["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
            ["我在纽约。", "纽约今天晴间多云,气温最高约26摄氏度,最低约18摄氏度,记得注意保暖喔。"]
        ]
    }
]

97  data/hh_rlhf_en/hh_rlhf_en.py  Normal file
@@ -0,0 +1,97 @@
import json
import datasets
from typing import Any, Dict, List


_DESCRIPTION = "Human preference data about helpfulness and harmlessness for LLaMA."
_CITATION = ""
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
_URL = "https://huggingface.co/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",
        _URL + "helpful-base/train.jsonl.gz",
        _URL + "helpful-online/train.jsonl.gz",
        _URL + "helpful-rejection-sampled/train.jsonl.gz"
    ],
    "test": [
        _URL + "harmless-base/test.jsonl.gz",
        _URL + "helpful-base/test.jsonl.gz",
        _URL + "helpful-online/test.jsonl.gz",
        _URL + "helpful-rejection-sampled/test.jsonl.gz"
    ]
}


class HhRlhfEn(datasets.GeneratorBasedBuilder):

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features({
            "instruction": datasets.Value("string"),
            "output": datasets.Sequence(datasets.Value("string")),
            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_path = dl_manager.download_and_extract(_URLS)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepaths": file_path["train"]
                }
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepaths": file_path["test"]
                }
            )
        ]

    def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]:  # parse multi-turn chats into pairwise records
        key = 0
        for filepath in filepaths:
            with open(filepath, "r", encoding="utf-8") as f:
                for row in f:
                    data = json.loads(row)
                    chosen = data["chosen"]
                    rejected = data["rejected"]

                    # The offsets 13 and 9 are len("\n\nAssistant: ") and len("\n\nHuman: ").
                    assist_idx = rejected.rfind("\n\nAssistant: ")
                    r_reject = rejected[assist_idx + 13:].strip()
                    assist_idx = chosen.rfind("\n\nAssistant: ")
                    r_accept = chosen[assist_idx + 13:].strip()

                    human_idx = chosen.rfind("\n\nHuman: ")
                    query = chosen[human_idx + 9:assist_idx].strip()
                    prompt = chosen[:human_idx]
                    history = []

                    # Walk backwards through the prompt, peeling off earlier (query, response) turns.
                    while prompt.rfind("\n\nAssistant: ") != -1:
                        assist_idx = prompt.rfind("\n\nAssistant: ")
                        human_idx = prompt.rfind("\n\nHuman: ")
                        if human_idx != -1:
                            old_query = prompt[human_idx + 9:assist_idx].strip()
                            old_resp = prompt[assist_idx + 13:].strip()
                            history.insert(0, (old_query, old_resp))
                        else:
                            break
                        prompt = prompt[:human_idx]

                    yield key, {
                        "instruction": query,
                        "output": [r_accept, r_reject],
                        "history": history
                    }
                    key += 1
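
To make the string surgery above concrete, here is a hand-run of the same slicing logic on a toy hh-rlhf record (a sketch; the strings are invented for illustration):

```python
chosen = "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a color\n\nAssistant: Blue"
rejected = "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a color\n\nAssistant: No"

# rfind locates the final turn; everything after it is the last answer.
# The record produced by _generate_examples would be:
# {
#     "instruction": "Name a color",
#     "output": ["Blue", "No"],         # [chosen answer, rejected answer]
#     "history": [("Hi", "Hello!")]     # earlier turns, oldest first
# }
```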

76  data/ultra_chat/ultra_chat.py  Normal file
@@ -0,0 +1,76 @@
import json
import datasets
from typing import Any, Dict, List


_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."

_CITATION = """\
@misc{UltraChat,
    author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
    title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
    year = {2023},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\\url{https://github.com/thunlp/ultrachat}},
}
"""

_HOMEPAGE = "https://huggingface.co/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"


class UltraChat(datasets.GeneratorBasedBuilder):  # renamed from the copy-pasted "BelleMultiturn"

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features({
            "instruction": datasets.Value("string"),
            "output": datasets.Value("string"),
            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)]  # multiple shards
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepaths": file_paths
                }
            )
        ]

    def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]:  # generate multi-turn chats
        for filepath in filepaths:
            with open(filepath, "r", encoding="utf-8") as f:
                for row in f:
                    try:
                        data = json.loads(row)
                    except json.JSONDecodeError:  # skip malformed lines instead of using a bare except
                        continue
                    key = data["id"]
                    content = data["data"]
                    if len(content) % 2 == 1:  # drop a trailing human turn without a reply
                        content.pop(-1)
                    if len(content) < 2:
                        continue

                    # The last (human, assistant) pair becomes the training example;
                    # all earlier pairs become the history.
                    query = content[-2]
                    response = content[-1]
                    history = [[content[2 * i], content[2 * i + 1]] for i in range(len(content) // 2 - 1)]

                    yield key, {
                        "instruction": query,
                        "output": response,
                        "history": history
                    }
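
The list-slicing convention above is easy to sanity-check by hand; for a six-turn conversation (a toy example, not real UltraChat data):

```python
content = ["q1", "a1", "q2", "a2", "q3", "a3"]

query, response = content[-2], content[-1]       # ("q3", "a3")
history = [[content[2 * i], content[2 * i + 1]]  # [["q1", "a1"], ["q2", "a2"]]
           for i in range(len(content) // 2 - 1)]
```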

0  src/__init__.py  Normal file

66  src/cli_demo.py  Normal file
@@ -0,0 +1,66 @@
# coding=utf-8
# Implements a command-line chat demo for LLaMA fine-tuned with PEFT.
# Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint


import torch
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser


def main():

    parser = HfArgumentParser(ModelArguments)
    model_args, = parser.parse_args_into_dataclasses()
    model, tokenizer = load_pretrained(model_args)
    if torch.cuda.device_count() > 1:  # shard the model across all visible GPUs
        from accelerate import dispatch_model
        device_map = auto_configure_device_map(torch.cuda.device_count())
        model = dispatch_model(model, device_map)
    else:
        model = model.cuda()
    model.eval()

    def predict(query, history: list):
        # Note: history is recorded but not folded back into the prompt in this demo.
        inputs = tokenizer([query], return_tensors="pt")
        inputs = inputs.to(model.device)
        gen_kwargs = {
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 40,
            "temperature": 0.7,
            "num_beams": 1,
            "max_new_tokens": 256,
            "repetition_penalty": 1.5
        }
        with torch.no_grad():
            generation_output = model.generate(**inputs, **gen_kwargs)
        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]  # strip the prompt tokens
        response = tokenizer.decode(outputs, skip_special_tokens=True)
        history = history + [(query, response)]
        return response, history

    history = []
    # Greeting: "Welcome to the LLaMA-7B model. Type to chat; 'clear' resets the history, 'stop' exits."
    print("欢迎使用 LLaMA-7B 模型,输入内容即可对话,clear清空对话历史,stop终止程序")
    while True:
        try:
            query = input("\nInput: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "stop":
            break

        if query.strip() == "clear":
            history = []
            continue

        response, history = predict(query, history)
        print("LLaMA-7B:", response)


if __name__ == "__main__":
    main()

23  src/export_model.py  Normal file
@@ -0,0 +1,23 @@
# coding=utf-8
# Exports the fine-tuned LLaMA model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model


from transformers import HfArgumentParser, TrainingArguments
from utils import ModelArguments, load_pretrained


def main():

    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # load_pretrained merges the LoRA checkpoint(s) into the base weights for inference
    model, tokenizer = load_pretrained(model_args)
    model.save_pretrained(training_args.output_dir, max_shard_size="1GB")
    tokenizer.save_pretrained(training_args.output_dir)

    print("Model and tokenizer have been saved at:", training_args.output_dir)


if __name__ == "__main__":
    main()
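
Since `load_pretrained` merges LoRA checkpoints before `save_pretrained` is called, the exported directory is a plain Hugging Face checkpoint. A minimal reload sketch, assuming the output path from the usage line above:

```python
from transformers import LlamaForCausalLM, LlamaTokenizer

# The exported directory holds the merged full weights plus the tokenizer files.
model = LlamaForCausalLM.from_pretrained("path_to_save_model")
tokenizer = LlamaTokenizer.from_pretrained("path_to_save_model")
```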

80  src/train_ppo.py  Normal file
@@ -0,0 +1,80 @@
# coding=utf-8
# Implements parameter-efficient PPO training of fine-tuned LLaMA.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py

import math

from torch.optim import AdamW

from transformers.optimization import get_scheduler
from trl import PPOConfig

from utils import (
    prepare_args,
    prepare_data,
    load_pretrained,
    preprocess_data,
    DataCollatorForLLaMA,
    PPOTrainerForLLaMA,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
    data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model)

    ppo_config = PPOConfig(
        model_name=model_args.model_name_or_path,
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.per_device_train_batch_size,
        batch_size=training_args.per_device_train_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        ppo_epochs=1,
        max_grad_norm=training_args.max_grad_norm
    )

    # Only the trainable (e.g. LoRA) parameters are passed to the optimizer.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
    total_train_batch_size = \
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
    lr_scheduler = get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
    )

    # Initialize our Trainer
    ppo_trainer = PPOTrainerForLLaMA(
        training_args=training_args,
        finetuning_args=finetuning_args,
        config=ppo_config,
        model=model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=dataset,
        data_collator=data_collator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler
    )

    ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
    ppo_trainer.save_model()
    ppo_trainer.save_state()  # must be called after save_model
    if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
        plot_loss(training_args, keys=["loss", "reward"])


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()

72  src/train_rm.py  Normal file
@@ -0,0 +1,72 @@
# coding=utf-8
# Implements parameter-efficient training of a reward model based on LLaMA.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py


from utils import (
    prepare_args,
    prepare_data,
    load_pretrained,
    preprocess_data,
    PairwiseDataCollatorForLLaMA,
    PairwiseTrainerForLLaMA,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)

    training_args.remove_unused_columns = False  # important for the pairwise dataset

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else:  # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PairwiseTrainerForLLaMA(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
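
Taken together with `src/train_ppo.py` above, a two-step RLHF run might look as follows. These invocations are illustrative sketches, not taken from the README: the flag names mirror the SFT example plus the `reward_model` and `checkpoint_dir` fields referenced in `src/utils`, and the comparison dataset comes from `data/dataset_info.json`:

```bash
# 1) Train a pairwise reward model with LoRA (train_rm.py):
CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
    --model_name_or_path llama_7b \
    --do_train \
    --dataset comparison_gpt4_zh \
    --finetuning_type lora \
    --output_dir path_to_rm_checkpoint \
    --fp16

# 2) Optimize the SFT model against that reward model with PPO (train_ppo.py):
CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
    --model_name_or_path llama_7b \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --finetuning_type lora \
    --checkpoint_dir path_to_sft_checkpoint \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --fp16
```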

95  src/train_sft.py  Normal file
@@ -0,0 +1,95 @@
# coding=utf-8
# Implements several parameter-efficient supervised fine-tuning methods for LLaMA.
# This code is inspired by:
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py


from utils import (
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    DataCollatorForLLaMA,
    Seq2SeqTrainerForLLaMA,
    ComputeMetrics,
    get_logits_processor,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
        training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.num_beams if \
        data_args.num_beams is not None else training_args.generation_num_beams

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else:  # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = Seq2SeqTrainerForLLaMA(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **trainer_kwargs
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_length": data_args.max_source_length + data_args.max_target_length + 1,
        "temperature": 0.95,
        "logits_processor": get_logits_processor()
    }

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results, tokenizer)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
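
`prepare_args` (defined in `src/utils/common.py` below) rejects `do_predict` runs that lack `predict_with_generate`, so a batch-prediction sketch for this script could look like this (the flags are assumptions mirroring the README training example):

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --model_name_or_path llama_7b \
    --do_predict \
    --dataset alpaca_gpt4_zh \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_predict_result \
    --predict_with_generate
```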

15  src/utils/__init__.py  Normal file
@@ -0,0 +1,15 @@
from .common import (
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data
)

from .data_collator import DataCollatorForLLaMA

from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
from .ppo import PPOTrainerForLLaMA

from .config import ModelArguments
from .other import auto_configure_device_map, get_logits_processor, plot_loss

459  src/utils/common.py  Normal file
@@ -0,0 +1,459 @@
import os
import sys
import torch
import hashlib
from typing import List, Literal, Optional, Tuple

import transformers
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    HfArgumentParser,
    Seq2SeqTrainingArguments
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

import datasets
from datasets import Dataset, concatenate_datasets, load_dataset

from peft import (
    PeftModel,
    TaskType,
    LoraConfig,
    get_peft_model
)

from trl import AutoModelForCausalLMWithValueHead

from .config import (
    ModelArguments,
    DataTrainingArguments,
    FinetuningArguments
)

from .other import (
    get_logger,
    load_trainable_params,
    load_valuehead_params,
    print_trainable_params,
    prepare_model_for_training,
    IGNORE_INDEX,
    FINETUNING_ARGS_NAME
)

check_min_version("4.29.1")
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")


logger = get_logger(__name__)


def init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Supports full-parameter, freeze and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """

    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")
        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

    if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
        load_trainable_params(model, model_args.checkpoint_dir[0])  # load model checkpoints for non-peft methods

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        latest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            if is_trainable and finetuning_args.resume_lora_training:  # continually train on the lora weights
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if latest_checkpoint is not None:  # resume lora training
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True)

        if is_trainable and latest_checkpoint is None:  # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    return model


def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: Optional[FinetuningArguments] = None,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads the pretrained model and tokenizer.

    Supports both training and inference.
    """

    if (not is_trainable) and (model_args.checkpoint_dir is None):
        logger.warning("No checkpoint found at evaluation, loading the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    if model_args.checkpoint_dir is not None:  # load fine-tuned model from checkpoint
        for checkpoint_dir in model_args.checkpoint_dir:
            if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
                raise ValueError("The fine-tuning arguments are not found in the provided directory.")
        logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
        finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
            logger.warning("Only LoRA tuning accepts multiple checkpoints.")

    assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with the LoRA method."

    tokenizer = LlamaTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left"
    )
    tokenizer.pad_token_id = 0  # set as the <unk> token

    # Quantization configurations (using bitsandbytes library).
    config_kwargs = {}
    if model_args.quantization_bit is not None:
        assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."

        require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
        from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
        cuda = get_cuda_lib_handle()
        cc = get_compute_capability(cuda)
        assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."

        config_kwargs["load_in_8bit"] = True
        config_kwargs["device_map"] = "auto"  # it should not be specified outside of load_in_8bit
        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

    # Load and prepare pretrained models (without valuehead).
    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    model = prepare_model_for_training(model) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if not is_trainable:
        model.requires_grad_(False)  # fix all model params
        model = model.half()  # cast all params to float16 for inference

    if stage == "rm" or stage == "ppo":  # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "ppo":  # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            load_valuehead_params(model, model_args.reward_model)

        # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
        # to meet the compliance requirements of the transformers library
        if model_args.quantization_bit is not None:
            model._is_int8_training_enabled = True

    print_trainable_params(model)

    return model, tokenizer


def prepare_args(
    stage: Literal["sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):  # provide arguments with a json file
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    if stage != "sft" and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True in RM and PPO stages.")

    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True while training.")

    if training_args.do_predict and (not training_args.predict_with_generate):
        raise ValueError("Please enable `predict_with_generate` for saving model predictions.")

    if model_args.quantization_bit is not None and (not training_args.do_train):
        logger.warning("Evaluating models in 4/8-bit mode may cause lower scores.")

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enabling fp16 mixed-precision training for LLaMA.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim  # suppress warning

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args


def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = []  # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
        elif dataset_attr.load_from == "script":
            raw_datasets = load_dataset(
                os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
                cache_dir=model_args.cache_dir
            )
        elif dataset_attr.load_from == "file":
            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)  # supports json, jsonl and csv
            extension = dataset_attr.file_name.split(".")[-1]

            if dataset_attr.file_sha1 is not None:
                checksum(data_file, dataset_attr.file_sha1)
            else:
                logger.warning("Checksum skipped: missing SHA-1 hash value in dataset_info.json.")

            raw_datasets = load_dataset(
                extension,
                data_files=data_file,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None
            )
        else:
            raise NotImplementedError

        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]:  # every dataset ends up with the same 4 columns
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else:  # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets


def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
) -> Dataset:

    column_names = list(dataset.column_names)
    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    def format_example(examples):  # support questions with a single answer or multiple answers
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                if examples["query"][i]:
                    query += examples["query"][i]
                prompt = "Below is an instruction that describes a task. "
                prompt += "Write a response that appropriately completes the request.\n"
                prompt += "Instruction:\n" + prefix
                if examples["history"][i]:
                    history = examples["history"][i]
                    for old_query, response in history:
                        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
                prompt += "Human: {}\nAssistant: ".format(query)
                yield prompt, answer

    def preprocess_supervised_dataset(examples):
        # build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 1:  # reserve a slot for the bos token
                source_ids = source_ids[:data_args.max_source_length - 1]
            if len(target_ids) > data_args.max_target_length - 1:  # reserve a slot for the eos token
                target_ids = target_ids[:data_args.max_target_length - 1]

            input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
            labels = [IGNORE_INDEX] * len(source_ids) + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
        return model_inputs
def preprocess_evaluation_dataset(examples):
|
||||
# build inputs with format `X <s>` and labels with format `Y <s>`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
for prompt, answer in format_example(examples):
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # bos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
input_ids = source_ids + [tokenizer.bos_token_id]
|
||||
labels = target_ids + [tokenizer.bos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `X <s> Y1 </s>` and `X <s> Y2 </s>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for prompt, answer in format_example(examples):
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
|
||||
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
return model_inputs
|
||||
|
||||
def print_sft_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
|
||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
|
||||
|
||||
def print_ppo_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
|
||||
|
||||
if stage == "sft":
|
||||
if (not training_args.do_train) and training_args.predict_with_generate: # with generation
|
||||
preprocess_function = preprocess_evaluation_dataset
|
||||
else: # without generation
|
||||
preprocess_function = preprocess_supervised_dataset
|
||||
elif stage == "rm":
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
elif stage == "ppo":
|
||||
preprocess_function = preprocess_evaluation_dataset
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
if stage == "sft":
|
||||
print_sft_dataset_example(dataset[0])
|
||||
elif stage == "rm":
|
||||
print_pairwise_dataset_example(dataset[0])
|
||||
elif stage == "ppo":
|
||||
print_ppo_dataset_example(dataset[0])
|
||||
|
||||
return dataset
|
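
# --- Editor's note: illustrative usage sketch, not part of the original file ---
# Assuming the argument dataclasses and a tokenizer have been prepared by the
# functions above and elsewhere in this module, the data pipeline chains as:
#
#   dataset = prepare_data(model_args, data_args)
#   dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
#
# For stage="sft" with do_train, each example is encoded as `X <s> Y </s>` in
# input_ids, with the prompt positions of labels set to IGNORE_INDEX so that
# only the response tokens contribute to the loss.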
212
src/utils/config.py
Normal file
212
src/utils/config.py
Normal file
@ -0,0 +1,212 @@
import os
import json
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class DatasetAttr:

    load_from: str
    dataset_name: Optional[str] = None
    file_name: Optional[str] = None
    file_sha1: Optional[str] = None

    def __post_init__(self):
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )

    def __post_init__(self):
        if self.checkpoint_dir is not None: # support merging lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    dataset: Optional[str] = field(
        default="alpaca_zh",
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
    )
    dataset_dir: Optional[str] = field(
        default="data",
        metadata={"help": "The name of the folder containing datasets."}
    )
    split: Optional[str] = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."}
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    max_target_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total output sequence length after tokenization."}
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`."}
    )
    ignore_pad_token_for_loss: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
    )
    source_prefix: Optional[str] = field(
        default=None,
        metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    dev_ratio: Optional[float] = field(
        default=0,
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
    )

    def __post_init__(self): # support mixing multiple datasets
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
            dataset_info = json.load(f)

        self.dataset_list: List[DatasetAttr] = []
        for name in dataset_names:
            if name not in dataset_info:
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))

            if "hf_hub_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
            elif "script_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
            else:
                dataset_attr = DatasetAttr(
                    "file",
                    file_name=dataset_info[name]["file_name"],
                    file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
                )

            if "columns" in dataset_info[name]:
                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)

            self.dataset_list.append(dataset_attr)


@dataclass
class FinetuningArguments:
    """
    Arguments pertaining to which techniques we are going to fine-tune with.
    """
    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."}
    )
    num_layer_trainable: Optional[int] = field(
        default=3,
        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
    )
    name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
        default="mlp",
        metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
    )
    lora_rank: Optional[int] = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
    )
    lora_alpha: Optional[float] = field(
        default=32.0,
        metadata={"help": "The scale factor for LoRA fine-tuning (similar to the learning rate)."}
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
    )
    lora_target: Optional[str] = field(
        default="q_proj,v_proj",
        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        if isinstance(self.lora_target, str):
            self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA

        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
            trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)] # the hardcoded 27 assumes a 28-layer model
        else: # fine-tuning the first n layers if num_layer_trainable < 0
            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]

        if self.name_module_trainable == "mlp":
            self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
        elif self.name_module_trainable == "qkv":
            self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids]

        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."

    def save_to_json(self, json_path: str):
        """Save the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        """Create an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))
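
# --- Editor's note: illustrative example, not part of the original file ---
# A hypothetical entry of data/dataset_info.json such as
#
#   {"my_data": {"file_name": "my_data.json",
#                "columns": {"prompt": "instruction", "query": "input",
#                            "response": "output", "history": null}}}
#
# is parsed by DataTrainingArguments.__post_init__ into
# DatasetAttr("file", file_name="my_data.json", file_sha1=None), with the four
# column attributes overridden by the "columns" mapping.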
67
src/utils/data_collator.py
Normal file
67
src/utils/data_collator.py
Normal file
@ -0,0 +1,67 @@
import torch

from typing import Dict, Optional, Sequence, Union

from transformers import DataCollatorWithPadding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

from .other import IGNORE_INDEX


class DataCollatorForLLaMA(DataCollatorWithPadding):
    r"""
    Data collator for LLaMA. It is capable of dynamically padding for batched data.
    """
    def __init__(
            self,
            tokenizer: PreTrainedTokenizer,
            model: PreTrainedModel,
            ignore_pad_token_for_loss: Optional[bool] = False
    ):
        super().__init__(tokenizer, padding=True)
        self.model = model
        self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id

    def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates attention masks for left-padded sequences.
        """
        batch_size, seq_length = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_length), device=device)
        for i, seq in enumerate(input_ids):
            attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
        attention_mask = attention_mask.bool()
        return attention_mask

    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We adopt left-padding in both training and evaluation.
        """
        if isinstance(features[0]["input_ids"], torch.Tensor):
            input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
        else:
            input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]

        if "labels" in features[0]:
            if isinstance(features[0]["labels"], torch.Tensor):
                labels = [feature["labels"].clone().detach().flip(0) for feature in features]
            else:
                labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
            input_ids = input_ids + labels # pad them to the same length

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)

        batch = {}

        if "labels" in features[0]:
            input_ids, labels = input_ids.split(len(features), dim=0)
            labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
            batch["labels"] = labels

        batch["input_ids"] = input_ids
        batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)

        return batch
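
# --- Editor's note: illustrative example, not part of the original file ---
# With a hypothetical tokenizer whose pad_token_id is 0, two features of
# lengths 3 and 2 are flipped, right-padded, and flipped back, which yields a
# left-padded batch whose attention mask zeroes out the leading pad:
#
#   collator = DataCollatorForLLaMA(tokenizer, model)
#   collator([{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}])["input_ids"]
#   # -> tensor([[5, 6, 7],
#   #            [0, 8, 9]])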
205
src/utils/other.py
Normal file
205
src/utils/other.py
Normal file
@ -0,0 +1,205 @@
import os
import sys
import json
import torch
import logging
from typing import Dict, List, Optional

from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor

from peft.utils.other import WEIGHTS_NAME


IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"


logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)]
)


def get_logger(name: str) -> logging.Logger:
    return logging.getLogger(name)


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Avoids runtime error in model.generate(do_sample=True).
# Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54
class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def get_logits_processor() -> LogitsProcessorList:
    logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    return logits_processor


# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
        model: PreTrainedModel,
        output_embedding_layer_name: Optional[str] = "lm_head",
        use_gradient_checkpointing: Optional[bool] = True,
        layer_norm_names: Optional[List[str]] = ["norm"] # for LLaMA setting
) -> PreTrainedModel:

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        model.enable_input_require_grads()
        model.gradient_checkpointing_enable()
        model.config.use_cache = False # turn off when gradient checkpointing is enabled

    if hasattr(model, output_embedding_layer_name):
        output_embedding_layer = getattr(model, output_embedding_layer_name)
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):

            def forward(self, x):
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model


def print_trainable_params(model: torch.nn.Module) -> None:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
        trainable_params, all_param, 100 * trainable_params / all_param))


def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
    state_dict = model.state_dict()
    filtered_state_dict = {}

    for k, v in model.named_parameters():
        if v.requires_grad:
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()

    return filtered_state_dict


def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
    model_state_dict = torch.load(weights_file, map_location="cpu")
    model.load_state_dict(model_state_dict, strict=False) # skip missing keys


def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    assert os.path.exists(valuehead_file), f"Provided path ({checkpoint_dir}) does not contain the valuehead weights."
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    r"""
    Configures device map for LLaMA.

    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8
    """
    num_layers = 28
    layers_per_gpu = 30 / num_gpus
    device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0}
    added_layers = 2
    target_gpu = 0

    for i in range(num_layers):
        if added_layers >= layers_per_gpu:
            target_gpu += 1
            added_layers = 0
        assert target_gpu < num_gpus
        device_map[f"model.layers.{i}"] = target_gpu
        added_layers += 1

    return device_map


def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
    """
    EMA implementation according to TensorBoard.
    """
    last = scalars[0]
    smoothed = list()
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed


def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
    import matplotlib.pyplot as plt
    with open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for i in range(len(data["log_history"])):
            if key in data["log_history"][i]:
                steps.append(data["log_history"][i]["step"])
                metrics.append(data["log_history"][i][key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), label="smoothed")
        plt.title("training {} of {}".format(key, training_args.output_dir))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
        print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
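
# --- Editor's note: illustrative example, not part of the original file ---
# `smooth` is an exponential moving average, smoothed[t] = w * smoothed[t-1]
# + (1 - w) * scalars[t], seeded with the first value; with weight=0.5 the
# sequence [0, 1, 1] becomes [0.0, 0.5, 0.75]:
if __name__ == "__main__":
    assert smooth([0.0, 1.0, 1.0], weight=0.5) == [0.0, 0.5, 0.75]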
51
src/utils/pairwise.py
Normal file
51
src/utils/pairwise.py
Normal file
@ -0,0 +1,51 @@
import torch
from typing import Dict, Sequence, Union

from .data_collator import DataCollatorForLLaMA

from .peft_trainer import PeftTrainer

from .other import get_logger

logger = get_logger(__name__)


class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
        return super().__call__(features)


class PairwiseTrainerForLLaMA(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True # override property to return eval_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        We use the score on the EOS token to represent the reward of the whole sentence.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        batch_size = inputs["input_ids"].size(0) // 2
        _, _, values = model(**inputs)
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
        outputs = {"r_accept": r_accept, "r_reject": r_reject}
        return (loss, outputs) if return_outputs else loss
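
# --- Editor's note: illustrative example, not part of the original file ---
# The pairwise loss -log(sigmoid(r_accept - r_reject)) shrinks as the chosen
# response outscores the rejected one; rewards of 1.0 and -1.0 give
# -log(sigmoid(2.0)), about 0.1269:
if __name__ == "__main__":
    demo_accept, demo_reject = torch.tensor([1.0]), torch.tensor([-1.0])
    demo_loss = -torch.log(torch.sigmoid(demo_accept - demo_reject)).mean()
    print(round(demo_loss.item(), 4))  # 0.1269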
78
src/utils/peft_trainer.py
Normal file
78
src/utils/peft_trainer.py
Normal file
@ -0,0 +1,78 @@
import os
import torch
from typing import Dict, Optional

from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model

from peft.utils.other import WEIGHTS_NAME

from .config import FinetuningArguments

from .other import (
    get_logger,
    get_state_dict,
    load_trainable_params,
    load_valuehead_params,
    FINETUNING_ARGS_NAME,
    VALUE_HEAD_FILE_NAME
)


logger = get_logger(__name__)


class PeftTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
    """

    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
        Saves trainable parameters as model checkpoint.

        This function will only be executed at the process zero.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        model = unwrap_model(self.model)

        if hasattr(model, "pretrained_model"): # for models with valuehead
            backbone_model = getattr(model, "pretrained_model")
        else:
            backbone_model = model

        if hasattr(backbone_model, "peft_config"): # peft methods
            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
        else:
            torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights

        if hasattr(model, "v_head"): # save valuehead weights
            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))

        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

    def _load_best_model(self):
        r"""
        Loads trainable parameters from model checkpoint.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        model = unwrap_model(self.model)
        if hasattr(model, "peft_config"): # peft methods
            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
        else:
            load_trainable_params(model, self.state.best_model_checkpoint)

        if hasattr(model, "v_head"):
            load_valuehead_params(model, self.state.best_model_checkpoint)
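
# --- Editor's note: illustrative sketch, not part of the original file ---
# Under the constants above, a LoRA checkpoint with a value head written by
# PeftTrainer._save would typically contain:
#
#   checkpoint-1000/
#     adapter_config.json, adapter_model.bin   # from backbone_model.save_pretrained
#     value_head.bin                           # VALUE_HEAD_FILE_NAME
#     training_args.bin                        # TRAINING_ARGS_NAME
#     finetuning_args.json                     # FINETUNING_ARGS_NAME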
241
src/utils/ppo.py
Normal file
241
src/utils/ppo.py
Normal file
@ -0,0 +1,241 @@
import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Tuple

from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TrainerState
from transformers.modeling_utils import PreTrainedModel

from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits

from .peft_trainer import PeftTrainer

from .config import FinetuningArguments

from .other import (
    AverageMeter,
    get_logger,
    get_logits_processor
)


logger = get_logger(__name__)


def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
    if target == "reward": # save original head temporarily
        valuehead_state_dict = model.v_head.state_dict()

        setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })


class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
    r"""
    Inherits PPOTrainer.
    """

    def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
        PPOTrainer.__init__(self, **kwargs)
        self.args = training_args
        self.finetuning_args = finetuning_args
        self.state = TrainerState()
        self.data_collator = self.accelerator.prepare(kwargs["data_collator"])

    def ppo_train(self, max_target_length: int) -> None:
        r"""
        Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
        """
        total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
        len_dataloader = len(self.dataloader)
        num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
        num_examples = len(self.dataset)
        num_train_epochs = self.args.num_train_epochs
        max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info(f"  Num examples = {num_examples}")
            logger.info(f"  Num Epochs = {num_train_epochs}")
            logger.info(f"  Instantaneous batch size per device = {self.config.batch_size}")
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
            logger.info(f"  Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
            logger.info(f"  Total optimization steps = {max_steps}")
            logger.info(f"  Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")

        # Keyword arguments for `model.generate`
        gen_kwargs = {
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "logits_processor": get_logits_processor()
        }
        output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)

        dataiter = iter(self.dataloader)
        steps_trained = 0
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()

        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):

            for _ in range(self.config.gradient_accumulation_steps):

                batch = next(dataiter)
                steps_trained += 1

                unwrapped_model.gradient_checkpointing_disable()
                unwrapped_model.config.use_cache = True

                # Get response from LLaMA
                query_tensors: torch.Tensor = batch["input_ids"]
                response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)

                queries: List[torch.Tensor] = []
                responses: List[torch.Tensor] = []
                for i in range(len(query_tensors)):
                    query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
                    response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
                    queries.append(query_tensors[i, query_length:]) # remove padding from left
                    if response_length < 2: # make response have at least 2 tokens
                        responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
                    else:
                        responses.append(response_tensors[i, :response_length]) # remove padding from right

                # Compute rewards
                replace_model(unwrapped_model, target="reward")
                _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
                rewards = [reward for reward in values[:, -1]]
                replace_model(unwrapped_model, target="default") # make sure the model is default at the end

                # Run PPO step
                unwrapped_model.gradient_checkpointing_enable()
                unwrapped_model.config.use_cache = False

                stats = self.step(queries, responses, rewards)

                loss_meter.update(stats["ppo/loss/total"])
                reward_meter.update(torch.tensor(rewards).sum().item(), n=len(rewards))

                if steps_trained == len_dataloader:
                    dataiter = iter(self.dataloader)
                    steps_trained = 0

            if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
                logs = {
                    "loss": round(loss_meter.avg, 4),
                    "reward": round(reward_meter.avg, 4),
                    "learning_rate": stats["ppo/learning_rate"],
                    "epoch": round(step / num_steps_per_epoch, 2)
                }
                print(logs)
                logs["step"] = step
                self.state.log_history.append(logs)
                loss_meter.reset()
                reward_meter.reset()

            if (step+1) % self.args.save_steps == 0: # save checkpoint
                self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))

    @torch.no_grad()
    def generate(
            self,
            inputs: Dict[str, torch.Tensor],
            length_sampler: Callable = None,
            return_prompt: bool = True,
            **generation_kwargs,
    ) -> torch.Tensor:
        r"""
        Generates model's responses given queries.

        Subclass and override to inject custom behavior.
        """
        if length_sampler is not None:
            generation_kwargs["max_new_tokens"] = length_sampler()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        response = unwrapped_model.generate(**inputs, **generation_kwargs)

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
        if unwrapped_model.pretrained_model.generation_config._from_model_config:
            unwrapped_model.pretrained_model.generation_config._from_model_config = False

        if not return_prompt and not self.is_encoder_decoder:
            return response[:, inputs["input_ids"].size(1):]
        return response

    def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
        input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
        input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
        input_data.pop("labels", None) # we don't want to compute LM losses
        return input_data

    @PPODecorators.empty_cuda_cache()
    def batched_forward_pass(
            self,
            model: AutoModelForCausalLMWithValueHead,
            queries: torch.Tensor,
            responses: torch.Tensor,
            model_inputs: dict,
    ):
        r"""
        Calculates model outputs in multiple batches.

        Subclass and override to inject custom behavior.
        """
        bs = len(model_inputs["input_ids"])
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(int(bs / fbs)):
            input_kwargs = {k: v[i * fbs : (i + 1) * fbs] for k, v in model_inputs.items()}
            input_ids: torch.Tensor = input_kwargs["input_ids"] # left-padded sequences
            logits, _, values = model(**input_kwargs)
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])

            masks = torch.zeros_like(input_ids)
            for j in range(fbs):
                start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item()
                masks[j][start:] = 1
                if len(masks[j][start:]) < 2:
                    raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")

            all_logits.append(logits)
            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1],
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.

        Subclass and override to inject custom behavior.
        """
        if self.args.should_save:
            self._save(output_dir)
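
# --- Editor's note: illustrative sketch, not part of the original file ---
# Each reward computation in ppo_train round-trips the adapters via
# replace_model:
#
#   replace_model(unwrapped_model, target="reward")   # switch to the "reward" LoRA adapter and value head
#   _, _, values = self.model(**inputs)               # values[:, -1] serve as sentence-level rewards
#   replace_model(unwrapped_model, target="default")  # switch back to the policy ("default") adapter and head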
96
src/utils/seq2seq.py
Normal file
96
src/utils/seq2seq.py
Normal file
@ -0,0 +1,96 @@
import os
import json
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple, Union

from transformers.trainer import PredictionOutput
from transformers.tokenization_utils import PreTrainedTokenizer

import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

from .peft_trainer import PeftTrainer

from .other import get_logger, IGNORE_INDEX


logger = get_logger(__name__)


@dataclass
class ComputeMetrics:
    r"""
    Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA.

    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
    """

    tokenizer: PreTrainedTokenizer

    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
        r"""
        Uses the model predictions to compute metrics.
        """
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
        for pred, label in zip(preds, labels):
            pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
            hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
            reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))

            if len(" ".join(hypothesis).split()) == 0:
                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
            else:
                rouge = Rouge()
                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
                result = scores[0]

            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))

            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        return {k: float(np.mean(v)) for k, v in score_dict.items()}


class Seq2SeqTrainerForLLaMA(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    def save_predictions(
            self,
            predict_results: PredictionOutput,
            tokenizer: PreTrainedTokenizer
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)

        preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
        preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
        labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for pred, label in zip(preds, labels):
                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
            writer.write("\n".join(res))
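
# --- Editor's note: illustrative sketch, not part of the original file ---
# ComputeMetrics is intended to be passed as the trainer's compute_metrics
# callback (hypothetical wiring):
#
#   trainer = Seq2SeqTrainerForLLaMA(..., compute_metrics=ComputeMetrics(tokenizer))
#
# Each prediction is truncated at its first <s> token to drop the query, then
# the jieba-segmented decoded strings are scored with ROUGE and BLEU-4.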
129
src/web_demo.py
Normal file
129
src/web_demo.py
Normal file
@ -0,0 +1,129 @@
# coding=utf-8
# Implements user interface in browser for LLaMA fine-tuned with PEFT.
# Usage: python web_demo.py --checkpoint_dir path_to_checkpoint


import torch
import mdtex2html
import gradio as gr

from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser


parser = HfArgumentParser(ModelArguments)
model_args, = parser.parse_args_into_dataclasses()
model, tokenizer = load_pretrained(model_args)
if torch.cuda.device_count() > 1:
    from accelerate import dispatch_model
    device_map = auto_configure_device_map(torch.cuda.device_count())
    model = dispatch_model(model, device_map)
else:
    model = model.cuda()
model.eval()


"""Override Chatbot.postprocess"""

def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text): # copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def predict(input, chatbot, max_length, top_p, temperature, history):
    chatbot.append((parse_text(input), ""))

    inputs = tokenizer([input], return_tensors="pt")
    inputs = inputs.to(model.device)
    gen_kwargs = {
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "num_beams": 1,
        "max_length": max_length,
        "repetition_penalty": 1.0
    }
    with torch.no_grad():
        generation_output = model.generate(**inputs, **gen_kwargs)
    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)
    history = history + [(input, response)]
    chatbot[-1] = (parse_text(input), parse_text(response))
    yield chatbot, history


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM-Efficient-Tuning</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True)
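
# --- Editor's note: illustrative example, not part of the original file ---
# Assuming the HTML-entity escapes reconstructed above, parse_text wraps
# fenced blocks in <pre><code> and escapes Markdown-sensitive characters
# inside them, e.g.:
#
#   parse_text("```python\nprint(1)\n```")
#   # -> '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'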