mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 16:12:48 +08:00
Merge branch 'main' of https://github.com/zhaonx/LLaMA-Factory into dev
Former-commit-id: 1abd55dd5987266280fef279c0a562f6c3e4835e
This commit is contained in:
commit
0c6c50f9b5
@ -11,4 +11,4 @@ RUN pip install -e .[deepspeed,metrics,bitsandbytes,qwen]
|
||||
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
|
||||
EXPOSE 7860
|
||||
|
||||
CMD [ "python", "src/train_web.py" ]
|
||||
CMD [ "llamafactory-cli webui" ]
|
||||
|
24
README.md
24
README.md
@ -5,7 +5,7 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](#projects-using-llama-factory)
|
||||
[](#projects-using-llama-factory)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
@ -339,16 +339,17 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
|
||||
### Train with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#command-line-interface) for distributed training.
|
||||
> LLaMA Board GUI only supports training on a single GPU, please use [CLI](#train-with-command-line-interface) for distributed training.
|
||||
|
||||
#### Use local environment
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0 # `set CUDA_VISIBLE_DEVICES=0` for Windows
|
||||
export GRADIO_SERVER_PORT=7860 # `set GRADIO_SERVER_PORT=7860` for Windows
|
||||
python src/train_web.py # or python -m llmtuner.webui.interface
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> To modify the default setting in the LLaMA Board GUI, you can use environment variables, e.g., `export CUDA_VISIBLE_DEVICES=0 GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 GRADIO_SHARE=False` (use `set` command on Windows OS).
|
||||
|
||||
<details><summary>For Alibaba Cloud users</summary>
|
||||
|
||||
If you encountered display problems in LLaMA Board on Alibaba Cloud, try using the following command to set environment variables before starting LLaMA Board:
|
||||
@ -392,12 +393,13 @@ docker compose -f ./docker-compose.yml up -d
|
||||
|
||||
See [examples/README.md](examples/README.md) for usage.
|
||||
|
||||
Use `python src/train_bash.py -h` to display arguments description.
|
||||
> [!TIP]
|
||||
> Use `llamafactory-cli train -h` to display arguments description.
|
||||
|
||||
### Deploy with OpenAI-style API and vLLM
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
|
||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--template llama3 \
|
||||
--infer_backend vllm \
|
||||
@ -441,6 +443,7 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||
@ -448,7 +451,14 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
||||
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
||||
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
||||
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
||||
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
|
24
README_zh.md
24
README_zh.md
@ -5,7 +5,7 @@
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
@ -339,16 +339,17 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
### 利用 LLaMA Board 可视化界面训练(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#命令行接口)来进行多 GPU 分布式训练。
|
||||
> LLaMA Board 可视化界面目前仅支持单 GPU 训练,请使用[命令行接口](#利用命令行接口训练)来进行多 GPU 分布式训练。
|
||||
|
||||
#### 使用本地环境
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0 # Windows 使用 `set CUDA_VISIBLE_DEVICES=0`
|
||||
export GRADIO_SERVER_PORT=7860 # Windows 使用 `set GRADIO_SERVER_PORT=7860`
|
||||
python src/train_web.py # 或 python -m llmtuner.webui.interface
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> 您可以使用环境变量来修改 LLaMA Board 可视化界面的默认设置,例如 `export CUDA_VISIBLE_DEVICES=0 GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7860 GRADIO_SHARE=False`(Windows 系统可使用 `set` 指令)。
|
||||
|
||||
<details><summary>阿里云用户指南</summary>
|
||||
|
||||
如果您在阿里云上使用 LLaMA Board 时遇到显示问题,请尝试在启动前使用以下命令设置环境变量:
|
||||
@ -392,12 +393,13 @@ docker compose -f ./docker-compose.yml up -d
|
||||
|
||||
使用方法请参考 [examples/README_zh.md](examples/README_zh.md)。
|
||||
|
||||
您可以执行 `python src/train_bash.py -h` 来查看参数文档。
|
||||
> [!TIP]
|
||||
> 您可以执行 `llamafactory-cli train -h` 来查看参数文档。
|
||||
|
||||
### 利用 vLLM 部署 OpenAI API
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 python src/api_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api \
|
||||
--model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--template llama3 \
|
||||
--infer_backend vllm \
|
||||
@ -441,6 +443,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
|
||||
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
|
||||
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
|
||||
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
|
||||
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
|
||||
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
|
||||
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
|
||||
@ -448,7 +451,14 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
|
||||
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
|
||||
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
|
||||
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
|
||||
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
|
||||
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
|
||||
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
|
||||
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
|
||||
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
|
||||
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
|
||||
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
|
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 186 KiB After Width: | Height: | Size: 123 KiB |
103
data/README.md
103
data/README.md
@ -1,4 +1,4 @@
|
||||
If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
|
||||
If you are using a custom dataset, please add your **dataset description** to `dataset_info.json` according to the following format. We also provide several examples in the next section.
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
@ -33,7 +33,7 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
}
|
||||
```
|
||||
|
||||
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
|
||||
After that, you can load the custom dataset by specifying `--dataset dataset_name`.
|
||||
|
||||
----
|
||||
|
||||
@ -54,10 +54,11 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
@ -70,28 +71,60 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||
|
||||
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.
|
||||
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
|
||||
|
||||
For the pre-training datasets, only the `prompt` column will be used for training.
|
||||
|
||||
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
|
||||
For the **pre-training datasets**, only the `prompt` column will be used for training, for example:
|
||||
|
||||
```json
|
||||
{
|
||||
[
|
||||
{"text": "document"},
|
||||
{"text": "document"}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "text"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
For the **preference datasets**, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Remember to set `"ranking": true` for the preference datasets.
|
||||
|
||||
----
|
||||
|
||||
The dataset in sharegpt format should follow the below format:
|
||||
The dataset in **sharegpt** format should follow the below format:
|
||||
|
||||
```json
|
||||
[
|
||||
@ -112,10 +145,12 @@ The dataset in sharegpt format should follow the below format:
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"system": "system",
|
||||
@ -132,4 +167,46 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
where the `messages` column should be a list following the `u/a/u/a/u/a` order.
|
||||
|
||||
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
|
||||
We also supports the dataset in the **openai** format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "system prompt (optional)"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "user instruction"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "model response"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant",
|
||||
"system_tag": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Pre-training datasets and preference datasets are **incompatible** with the sharegpt format yet.
|
||||
|
@ -1,4 +1,4 @@
|
||||
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中按照以下格式提供数据集定义。
|
||||
如果您使用自定义数据集,请务必按照以下格式在 `dataset_info.json` 文件中添加**数据集描述**。我们在下面也提供了一些例子。
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
@ -33,7 +33,7 @@
|
||||
}
|
||||
```
|
||||
|
||||
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
|
||||
然后,可通过使用 `--dataset 数据集名称` 参数加载自定义数据集。
|
||||
|
||||
----
|
||||
|
||||
@ -54,10 +54,11 @@
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
@ -70,28 +71,60 @@
|
||||
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**也会被用于训练**。
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
|
||||
|
||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||
|
||||
对于偏好数据集,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
|
||||
对于**预训练数据集**,仅 `prompt` 列中的内容会用于模型训练,例如:
|
||||
|
||||
```json
|
||||
{
|
||||
[
|
||||
{"text": "document"},
|
||||
{"text": "document"}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"columns": {
|
||||
"prompt": "text"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
对于**偏好数据集**,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"ranking": true,
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
添加偏好数据集需要额外指定 `"ranking": true`。
|
||||
|
||||
----
|
||||
|
||||
而 sharegpt 格式的数据集按照以下方式组织:
|
||||
而 **sharegpt** 格式的数据集按照以下方式组织:
|
||||
|
||||
```json
|
||||
[
|
||||
@ -112,10 +145,12 @@
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"system": "system",
|
||||
@ -132,4 +167,46 @@
|
||||
|
||||
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
|
||||
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
||||
我们同样支持 **openai** 格式的数据集:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "系统提示词(选填)"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "用户指令"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "模型回答"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的描述应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"file_name": "data.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant",
|
||||
"system_tag": "system"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
预训练数据集和偏好数据集**尚不支持** sharegpt 格式。
|
||||
|
@ -133,25 +133,19 @@ class Ceval(datasets.GeneratorBasedBuilder):
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "test", f"{task_name}_test.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "test", f"{task_name}_test.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "val", f"{task_name}_val.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "val", f"{task_name}_val.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "dev", f"{task_name}_dev.csv"),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@ -37,73 +37,73 @@ _LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 Internatio
|
||||
_URL = "cmmlu.zip"
|
||||
|
||||
task_list = [
|
||||
'agronomy',
|
||||
'anatomy',
|
||||
'ancient_chinese',
|
||||
'arts',
|
||||
'astronomy',
|
||||
'business_ethics',
|
||||
'chinese_civil_service_exam',
|
||||
'chinese_driving_rule',
|
||||
'chinese_food_culture',
|
||||
'chinese_foreign_policy',
|
||||
'chinese_history',
|
||||
'chinese_literature',
|
||||
'chinese_teacher_qualification',
|
||||
'clinical_knowledge',
|
||||
'college_actuarial_science',
|
||||
'college_education',
|
||||
'college_engineering_hydrology',
|
||||
'college_law',
|
||||
'college_mathematics',
|
||||
'college_medical_statistics',
|
||||
'college_medicine',
|
||||
'computer_science',
|
||||
'computer_security',
|
||||
'conceptual_physics',
|
||||
'construction_project_management',
|
||||
'economics',
|
||||
'education',
|
||||
'electrical_engineering',
|
||||
'elementary_chinese',
|
||||
'elementary_commonsense',
|
||||
'elementary_information_and_technology',
|
||||
'elementary_mathematics',
|
||||
'ethnology',
|
||||
'food_science',
|
||||
'genetics',
|
||||
'global_facts',
|
||||
'high_school_biology',
|
||||
'high_school_chemistry',
|
||||
'high_school_geography',
|
||||
'high_school_mathematics',
|
||||
'high_school_physics',
|
||||
'high_school_politics',
|
||||
'human_sexuality',
|
||||
'international_law',
|
||||
'journalism',
|
||||
'jurisprudence',
|
||||
'legal_and_moral_basis',
|
||||
'logical',
|
||||
'machine_learning',
|
||||
'management',
|
||||
'marketing',
|
||||
'marxist_theory',
|
||||
'modern_chinese',
|
||||
'nutrition',
|
||||
'philosophy',
|
||||
'professional_accounting',
|
||||
'professional_law',
|
||||
'professional_medicine',
|
||||
'professional_psychology',
|
||||
'public_relations',
|
||||
'security_study',
|
||||
'sociology',
|
||||
'sports_science',
|
||||
'traditional_chinese_medicine',
|
||||
'virology',
|
||||
'world_history',
|
||||
'world_religions',
|
||||
"agronomy",
|
||||
"anatomy",
|
||||
"ancient_chinese",
|
||||
"arts",
|
||||
"astronomy",
|
||||
"business_ethics",
|
||||
"chinese_civil_service_exam",
|
||||
"chinese_driving_rule",
|
||||
"chinese_food_culture",
|
||||
"chinese_foreign_policy",
|
||||
"chinese_history",
|
||||
"chinese_literature",
|
||||
"chinese_teacher_qualification",
|
||||
"clinical_knowledge",
|
||||
"college_actuarial_science",
|
||||
"college_education",
|
||||
"college_engineering_hydrology",
|
||||
"college_law",
|
||||
"college_mathematics",
|
||||
"college_medical_statistics",
|
||||
"college_medicine",
|
||||
"computer_science",
|
||||
"computer_security",
|
||||
"conceptual_physics",
|
||||
"construction_project_management",
|
||||
"economics",
|
||||
"education",
|
||||
"electrical_engineering",
|
||||
"elementary_chinese",
|
||||
"elementary_commonsense",
|
||||
"elementary_information_and_technology",
|
||||
"elementary_mathematics",
|
||||
"ethnology",
|
||||
"food_science",
|
||||
"genetics",
|
||||
"global_facts",
|
||||
"high_school_biology",
|
||||
"high_school_chemistry",
|
||||
"high_school_geography",
|
||||
"high_school_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_politics",
|
||||
"human_sexuality",
|
||||
"international_law",
|
||||
"journalism",
|
||||
"jurisprudence",
|
||||
"legal_and_moral_basis",
|
||||
"logical",
|
||||
"machine_learning",
|
||||
"management",
|
||||
"marketing",
|
||||
"marxist_theory",
|
||||
"modern_chinese",
|
||||
"nutrition",
|
||||
"philosophy",
|
||||
"professional_accounting",
|
||||
"professional_law",
|
||||
"professional_medicine",
|
||||
"professional_psychology",
|
||||
"public_relations",
|
||||
"security_study",
|
||||
"sociology",
|
||||
"sports_science",
|
||||
"traditional_chinese_medicine",
|
||||
"virology",
|
||||
"world_history",
|
||||
"world_religions",
|
||||
]
|
||||
|
||||
|
||||
|
@ -136,25 +136,19 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "test", f"{task_name}_test.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "data", "test", f"{task_name}_test.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "val", f"{task_name}_val.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "data", "val", f"{task_name}_val.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
"filepath": os.path.join(data_dir, "data", "dev", f"{task_name}_dev.csv"),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
@ -10,7 +10,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
--finetuning_type full \
|
||||
--use_badam \
|
||||
--badam_switch_mode descending \
|
||||
--badam_switch_block_every 50 \
|
||||
--badam_switch_interval 50 \
|
||||
--badam_verbose 2 \
|
||||
--output_dir ../../../saves/LLaMA2-7B/badam/sft \
|
||||
--overwrite_cache \
|
||||
|
@ -7,7 +7,7 @@ pip install "bitsandbytes>=0.43.0"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
|
||||
--config_file ../../accelerate/fsdp_config.yaml \
|
||||
../../../src/train_bash.py \
|
||||
../../../src/train.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-70b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path ../../../models/llama2-7b-pro \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -6,7 +6,7 @@ python -m torch.distributed.run \
|
||||
--node_rank $RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/single_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--stage sft \
|
||||
--do_predict \
|
||||
--model_name_or_path ../../saves/LLaMA2-7B/full/sft \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||
deepspeed --num_gpus 4 ../../src/train.py \
|
||||
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 python ../../src/api_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0 API_PORT=8000 llamafactory-cli api \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/cli_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/evaluate.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template fewshot \
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
# add `--visual_inputs True` to load MLLM
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
@ -1,6 +1,7 @@
|
||||
#!/bin/bash
|
||||
# ZeRO-3 enables weight sharding on multiple GPUs
|
||||
|
||||
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||
deepspeed --num_gpus 4 ../../src/train.py \
|
||||
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/master_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||
--config_file ../accelerate/single_config.yaml \
|
||||
../../src/train_bash.py \
|
||||
../../src/train.py \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage dpo \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage orpo \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage ppo \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_predict \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,7 @@
|
||||
#!/bin/bash
|
||||
# use `--tokenized_path` in training script to load data
|
||||
|
||||
CUDA_VISIBLE_DEVICES= python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES= llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage pt \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage rm \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path llava-hf/llava-1.5-7b-hf \
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
# DO NOT use quantized model or quantization_bit when merging lora weights
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
--adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
|
||||
--template default \
|
||||
|
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
# NEED TO run `merge.sh` before using this script
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
|
||||
--model_name_or_path ../../models/llama2-7b-sft \
|
||||
--template default \
|
||||
--export_dir ../../models/llama2-7b-sft-int4 \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path TheBloke/Llama-2-7B-AWQ \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
|
||||
--stage sft \
|
||||
--do_train \
|
||||
--model_name_or_path TheBloke/Llama-2-7B-GPTQ \
|
||||
|
@ -16,3 +16,4 @@ sse-starlette
|
||||
matplotlib
|
||||
fire
|
||||
packaging
|
||||
pyyaml
|
||||
|
@ -3,24 +3,22 @@
|
||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from deepspeed.accelerator import get_accelerator # type: ignore
|
||||
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.chat import ChatModel
|
||||
|
||||
|
||||
def calculate_flops(
|
||||
model_name_or_path: str,
|
||||
batch_size: Optional[int] = 1,
|
||||
seq_length: Optional[int] = 256,
|
||||
flash_attn: Optional[bool] = False,
|
||||
batch_size: int = 1,
|
||||
seq_length: int = 256,
|
||||
flash_attn: str = "auto",
|
||||
):
|
||||
with get_accelerator().device(0):
|
||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
|
||||
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
|
||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
|
||||
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
|
||||
|
@ -4,7 +4,7 @@
|
||||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Literal
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@ -25,12 +25,12 @@ BASE_BS = 4_000_000 # from llama paper
|
||||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
stage: Optional[str] = "sft",
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
|
||||
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
|
||||
stage: Literal["pt", "sft"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||
is_mistral: bool = False, # mistral model uses a smaller learning rate,
|
||||
):
|
||||
model_args, data_args, training_args, _, _ = get_train_args(
|
||||
dict(
|
||||
@ -54,9 +54,7 @@ def calculate_lr(
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
|
||||
)
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
valid_tokens, total_tokens = 0, 0
|
||||
for batch in tqdm(dataloader):
|
||||
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
||||
|
116
scripts/cal_ppl.py
Normal file
116
scripts/cal_ppl.py
Normal file
@ -0,0 +1,116 @@
|
||||
# coding=utf-8
|
||||
# Calculates the ppl on the dataset of the pre-trained models.
|
||||
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Optional, Sequence
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.hparams import get_train_args
|
||||
from llmtuner.model import load_model, load_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
train_on_prompt: bool = False
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
chosen_features = []
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
|
||||
input_ids = feature["prompt_ids"] + feature["chosen_ids"]
|
||||
attention_mask = [1] * (prompt_len + answer_len)
|
||||
labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
|
||||
chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
|
||||
|
||||
return super().__call__(chosen_features)
|
||||
|
||||
|
||||
def cal_ppl(
|
||||
model_name_or_path: str,
|
||||
save_name: str,
|
||||
batch_size: int = 4,
|
||||
stage: Literal["pt", "sft", "rm"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024,
|
||||
max_samples: Optional[int] = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
dict(
|
||||
stage=stage,
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template=template,
|
||||
cutoff_len=cutoff_len,
|
||||
max_samples=max_samples,
|
||||
train_on_prompt=train_on_prompt,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
elif stage == "rm":
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
total_ppl = 0
|
||||
perplexities = []
|
||||
batch: Dict[str, "torch.Tensor"]
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader):
|
||||
batch = batch.to(model.device)
|
||||
outputs = model(**batch)
|
||||
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
|
||||
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
|
||||
loss_mask = shift_labels != IGNORE_INDEX
|
||||
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
|
||||
flatten_labels = shift_labels.contiguous().view(-1)
|
||||
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
|
||||
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
|
||||
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
total_ppl += sentence_logps.exp().sum().item()
|
||||
perplexities.extend(sentence_logps.exp().tolist())
|
||||
|
||||
with open(save_name, "w", encoding="utf-8") as f:
|
||||
json.dump(perplexities, f, indent=2)
|
||||
|
||||
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
|
||||
print("Perplexities have been saved at {}.".format(save_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(cal_ppl)
|
@ -3,7 +3,6 @@
|
||||
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
from tqdm import tqdm
|
||||
@ -15,10 +14,10 @@ from llmtuner.model import load_tokenizer
|
||||
|
||||
def length_cdf(
|
||||
model_name_or_path: str,
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
interval: Optional[int] = 1000,
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
interval: int = 1000,
|
||||
):
|
||||
model_args, data_args, training_args, _, _ = get_train_args(
|
||||
dict(
|
||||
|
1
setup.py
1
setup.py
@ -52,6 +52,7 @@ def main():
|
||||
python_requires=">=3.8.0",
|
||||
install_requires=get_requires(),
|
||||
extras_require=extra_require,
|
||||
entry_points={"console_scripts": ["llamafactory-cli = llmtuner.cli:main"]},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
|
@ -1,16 +0,0 @@
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llmtuner import ChatModel, create_app
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8000)))
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,49 +0,0 @@
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
|
||||
|
||||
try:
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("\nUser: ")
|
||||
except UnicodeDecodeError:
|
||||
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
||||
continue
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if query.strip() == "exit":
|
||||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
messages = []
|
||||
torch_gc()
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in chat_model.stream_chat(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,9 +0,0 @@
|
||||
from llmtuner import Evaluator
|
||||
|
||||
|
||||
def main():
|
||||
Evaluator().eval()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,9 +0,0 @@
|
||||
from llmtuner import export_model
|
||||
|
||||
|
||||
def main():
|
||||
export_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,11 +1,3 @@
|
||||
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
||||
|
||||
from .api import create_app
|
||||
from .chat import ChatModel
|
||||
from .eval import Evaluator
|
||||
from .train import export_model, run_exp
|
||||
from .webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.7.0"
|
||||
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
|
||||
__version__ = "0.7.1.dev0"
|
||||
|
@ -1,4 +0,0 @@
|
||||
from .app import create_app
|
||||
|
||||
|
||||
__all__ = ["create_app"]
|
@ -1,36 +1,29 @@
|
||||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.misc import torch_gc
|
||||
from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
||||
from .chat import (
|
||||
create_chat_completion_response,
|
||||
create_score_evaluation_response,
|
||||
create_stream_chat_completion_response,
|
||||
)
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
Role,
|
||||
ScoreEvaluationRequest,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
if is_starlette_available():
|
||||
@ -47,23 +40,8 @@ async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||
torch_gc()
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@ -71,162 +49,58 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
api_key = os.environ.get("API_KEY", None)
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
role_mapping = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
if api_key and (auth is None or auth.credentials != api_key):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
@app.get(
|
||||
"/v1/models",
|
||||
response_model=ModelList,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
@app.post(
|
||||
"/v1/chat/completions",
|
||||
response_model=ChatCompletionResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if not chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
|
||||
else:
|
||||
input_messages.append({"role": role_mapping[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
if request.stream:
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
generate = stream_chat_completion(input_messages, system, tools, request)
|
||||
generate = create_stream_chat_completion_response(request, chat_model)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n,
|
||||
stop=request.stop
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
return await create_chat_completion_response(request, chat_model)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
response_message = ChatCompletionMessage(
|
||||
role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
|
||||
@app.post(
|
||||
"/v1/score/evaluation",
|
||||
response_model=ScoreEvaluationResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(
|
||||
ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
|
||||
)
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def stream_chat_completion(
|
||||
messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
|
||||
):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
||||
async for new_token in chat_model.astream_chat(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
stop=request.stop
|
||||
):
|
||||
if len(new_token) == 0:
|
||||
continue
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(content=new_token), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield jsonify(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
@app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_score_evaluation(request: ScoreEvaluationRequest):
|
||||
if chat_model.engine.can_generate:
|
||||
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
||||
return await create_score_evaluation_response(request, chat_model)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def run_api() -> None:
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
177
src/llmtuner/api/chat.py
Normal file
177
src/llmtuner/api/chat.py
Normal file
@ -0,0 +1,177 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.packages import is_fastapi_availble
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseUsage,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionStreamResponseChoice,
|
||||
Finish,
|
||||
Function,
|
||||
FunctionCall,
|
||||
Role,
|
||||
ScoreEvaluationResponse,
|
||||
)
|
||||
|
||||
|
||||
if is_fastapi_availble():
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..chat import ChatModel
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
ROLE_MAPPING = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
Role.SYSTEM: DataRole.SYSTEM.value,
|
||||
Role.FUNCTION: DataRole.FUNCTION.value,
|
||||
Role.TOOL: DataRole.OBSERVATION.value,
|
||||
}
|
||||
|
||||
|
||||
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
system = request.messages.pop(0).content
|
||||
else:
|
||||
system = ""
|
||||
|
||||
if len(request.messages) % 2 == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
tool_list = request.tools
|
||||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = ""
|
||||
|
||||
return input_messages, system, tools
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
completion_id: str,
|
||||
model: str,
|
||||
delta: "ChatCompletionMessage",
|
||||
index: Optional[int] = 0,
|
||||
finish_reason: Optional["Finish"] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
|
||||
return jsonify(chunk)
|
||||
|
||||
|
||||
async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n,
|
||||
)
|
||||
|
||||
prompt_length, response_length = 0, 0
|
||||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
function = Function(name=name, arguments=arguments)
|
||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
||||
|
||||
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
|
||||
prompt_length = response.prompt_length
|
||||
response_length += response.response_length
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
prompt_tokens=prompt_length,
|
||||
completion_tokens=response_length,
|
||||
total_tokens=prompt_length + response_length,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
|
||||
|
||||
|
||||
async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
|
||||
)
|
||||
async for new_token in chat_model.astream_chat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
):
|
||||
if len(new_token) != 0:
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
|
||||
)
|
||||
|
||||
yield _create_stream_chat_completion_chunk(
|
||||
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
|
||||
)
|
||||
yield "[DONE]"
|
||||
|
||||
|
||||
async def create_score_evaluation_response(
|
||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||
) -> "ScoreEvaluationResponse":
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
|
||||
return ScoreEvaluationResponse(model=request.model, scores=scores)
|
20
src/llmtuner/api/common.py
Normal file
20
src/llmtuner/api/common.py
Normal file
@ -0,0 +1,20 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def dictify(data: "BaseModel") -> Dict[str, Any]:
|
||||
try: # pydantic v2
|
||||
return data.model_dump(exclude_unset=True)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.dict(exclude_unset=True)
|
||||
|
||||
|
||||
def jsonify(data: "BaseModel") -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except AttributeError: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
@ -51,7 +51,7 @@ class FunctionAvailable(BaseModel):
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
id: Literal["call_default"] = "call_default"
|
||||
id: str
|
||||
type: Literal["function"] = "function"
|
||||
function: Function
|
||||
|
||||
@ -87,7 +87,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
finish_reason: Finish
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
|
||||
index: int
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
@ -100,7 +100,7 @@ class ChatCompletionResponseUsage(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
@ -109,11 +109,11 @@ class ChatCompletionResponse(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: Literal["chatcmpl-default"] = "chatcmpl-default"
|
||||
id: str
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
choices: List[ChatCompletionStreamResponseChoice]
|
||||
|
||||
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
@ -123,7 +123,7 @@ class ScoreEvaluationRequest(BaseModel):
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
id: Literal["scoreeval-default"] = "scoreeval-default"
|
||||
id: str
|
||||
object: Literal["score.evaluation"] = "score.evaluation"
|
||||
model: str
|
||||
scores: List[float]
|
||||
|
@ -2,6 +2,7 @@ import asyncio
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||
|
||||
from ..extras.misc import torch_gc
|
||||
from ..hparams import get_infer_args
|
||||
from .hf_engine import HuggingfaceEngine
|
||||
from .vllm_engine import VllmEngine
|
||||
@ -95,3 +96,45 @@ class ChatModel:
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||
|
||||
|
||||
def run_chat() -> None:
|
||||
try:
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
chat_model = ChatModel()
|
||||
messages = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("\nUser: ")
|
||||
except UnicodeDecodeError:
|
||||
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
||||
continue
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
if query.strip() == "exit":
|
||||
break
|
||||
|
||||
if query.strip() == "clear":
|
||||
messages = []
|
||||
torch_gc()
|
||||
print("History has been removed.")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": query})
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
response = ""
|
||||
for new_text in chat_model.stream_chat(messages):
|
||||
print(new_text, end="", flush=True)
|
||||
response += new_text
|
||||
print()
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
59
src/llmtuner/cli.py
Normal file
59
src/llmtuner/cli.py
Normal file
@ -0,0 +1,59 @@
|
||||
import sys
|
||||
from enum import Enum, unique
|
||||
|
||||
from . import __version__
|
||||
from .api.app import run_api
|
||||
from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
|
||||
|
||||
USAGE = """
|
||||
Usage:
|
||||
llamafactory-cli api -h: launch an API server
|
||||
llamafactory-cli chat -h: launch a chat interface in CLI
|
||||
llamafactory-cli eval -h: do evaluation
|
||||
llamafactory-cli export -h: merge LoRA adapters and export model
|
||||
llamafactory-cli train -h: do training
|
||||
llamafactory-cli webchat -h: launch a chat interface in Web UI
|
||||
llamafactory-cli webui: launch LlamaBoard
|
||||
llamafactory-cli version: show version info
|
||||
"""
|
||||
|
||||
|
||||
@unique
|
||||
class Command(str, Enum):
|
||||
API = "api"
|
||||
CHAT = "chat"
|
||||
EVAL = "eval"
|
||||
EXPORT = "export"
|
||||
TRAIN = "train"
|
||||
WEBDEMO = "webchat"
|
||||
WEBUI = "webui"
|
||||
VERSION = "version"
|
||||
HELP = "help"
|
||||
|
||||
|
||||
def main():
|
||||
command = sys.argv.pop(1)
|
||||
if command == Command.API:
|
||||
run_api()
|
||||
elif command == Command.CHAT:
|
||||
run_chat()
|
||||
elif command == Command.EVAL:
|
||||
run_eval()
|
||||
elif command == Command.EXPORT:
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
run_exp()
|
||||
elif command == Command.WEBDEMO:
|
||||
run_web_demo()
|
||||
elif command == Command.WEBUI:
|
||||
run_web_ui()
|
||||
elif command == Command.VERSION:
|
||||
print("Welcome to LLaMA Factory, version {}".format(__version__))
|
||||
elif command == Command.HELP:
|
||||
print(USAGE)
|
||||
else:
|
||||
raise NotImplementedError("Unknown command: {}".format(command))
|
@ -1,4 +0,0 @@
|
||||
from .evaluator import Evaluator
|
||||
|
||||
|
||||
__all__ = ["Evaluator"]
|
@ -118,6 +118,5 @@ class Evaluator:
|
||||
f.write(score_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluator = Evaluator()
|
||||
evaluator.eval()
|
||||
def run_eval() -> None:
|
||||
Evaluator().eval()
|
||||
|
@ -1,14 +1,19 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import transformers
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||
|
||||
from .constants import LOG_FILE_NAME
|
||||
from .logging import get_logger
|
||||
from .constants import TRAINER_LOG
|
||||
from .logging import LoggerHandler, get_logger
|
||||
from .misc import fix_valuehead_checkpoint
|
||||
|
||||
|
||||
@ -33,57 +38,92 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.in_training = False
|
||||
self.start_time = time.time()
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
r"""
|
||||
Initializes a callback for logging training and evaluation status.
|
||||
"""
|
||||
""" Progress """
|
||||
self.start_time = 0
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
""" Status """
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
""" Web UI """
|
||||
self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
|
||||
if self.webui_mode:
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = LoggerHandler(output_dir)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
def timing(self):
|
||||
def _set_abort(self, signum, frame) -> None:
|
||||
self.aborted = True
|
||||
|
||||
def _reset(self, max_steps: int = 0) -> None:
|
||||
self.start_time = time.time()
|
||||
self.cur_steps = 0
|
||||
self.max_steps = max_steps
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
|
||||
def _timing(self, cur_steps: int) -> None:
|
||||
cur_time = time.time()
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
|
||||
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
|
||||
self.cur_steps = cur_steps
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def _create_thread_pool(self, output_dir: str) -> None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _close_thread_pool(self) -> None:
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
self.thread_pool = None
|
||||
|
||||
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of the initialization of the `Trainer`.
|
||||
"""
|
||||
if (
|
||||
args.should_save
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
and args.overwrite_output_dir
|
||||
):
|
||||
logger.warning("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
if state.is_local_process_zero:
|
||||
self.in_training = True
|
||||
self.start_time = time.time()
|
||||
self.max_steps = state.max_steps
|
||||
|
||||
if args.save_on_each_node:
|
||||
if not state.is_local_process_zero:
|
||||
return
|
||||
else:
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||
if args.should_save:
|
||||
self.do_train = True
|
||||
self._reset(max_steps=state.max_steps)
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if state.is_local_process_zero:
|
||||
self.in_training = False
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
@ -91,10 +131,7 @@ class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
Event called at the end of a training step.
|
||||
"""
|
||||
if state.is_local_process_zero:
|
||||
self.cur_steps = state.global_step
|
||||
self.timing()
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
if self.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
@ -102,31 +139,22 @@ class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
if state.is_local_process_zero and not self.in_training:
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_predict(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
|
||||
):
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
if state.is_local_process_zero and not self.in_training:
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if args.save_on_each_node:
|
||||
if not state.is_local_process_zero:
|
||||
return
|
||||
else:
|
||||
if not state.is_world_process_zero:
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
self._timing(cur_steps=state.global_step)
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
@ -141,16 +169,16 @@ class LogCallback(TrainerCallback):
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
if self.runner is not None:
|
||||
logs = {k: v for k, v in logs.items() if v is not None}
|
||||
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
|
||||
logger.info(
|
||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||
logs["loss"], logs["learning_rate"], logs["epoch"]
|
||||
)
|
||||
)
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
@ -158,9 +186,28 @@ class LogCallback(TrainerCallback):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
if self.do_train:
|
||||
return
|
||||
|
||||
if self.aborted:
|
||||
sys.exit(0)
|
||||
|
||||
if not args.should_save:
|
||||
return
|
||||
|
||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
|
||||
if has_length(eval_dataloader):
|
||||
if self.max_steps == 0:
|
||||
self.max_steps = len(eval_dataloader)
|
||||
self.cur_steps += 1
|
||||
self.timing()
|
||||
self._reset(max_steps=len(eval_dataloader))
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
self._timing(cur_steps=self.cur_steps + 1)
|
||||
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
@ -24,8 +24,6 @@ IGNORE_INDEX = -100
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
MLLM_LIST = ["LLaVA1.5"]
|
||||
@ -34,10 +32,16 @@ MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral
|
||||
|
||||
PEFT_METHODS = ["lora"]
|
||||
|
||||
RUNNING_LOG = "running_log.txt"
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
TRAINER_CONFIG = "trainer_config.yaml"
|
||||
|
||||
TRAINER_LOG = "trainer_log.jsonl"
|
||||
|
||||
TRAINING_STAGES = {
|
||||
"Supervised Fine-Tuning": "sft",
|
||||
"Reward Modeling": "rm",
|
||||
|
@ -1,5 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from .constants import RUNNING_LOG
|
||||
|
||||
|
||||
class LoggerHandler(logging.Handler):
|
||||
@ -7,19 +11,35 @@ class LoggerHandler(logging.Handler):
|
||||
Logger handler used in Web UI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
super().__init__()
|
||||
self.log = ""
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||
)
|
||||
self.setLevel(logging.INFO)
|
||||
self.setFormatter(formatter)
|
||||
|
||||
def reset(self):
|
||||
self.log = ""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
||||
if os.path.exists(self.running_log):
|
||||
os.remove(self.running_log)
|
||||
|
||||
def emit(self, record):
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _write_log(self, log_entry: str) -> None:
|
||||
with open(self.running_log, "a", encoding="utf-8") as f:
|
||||
f.write(log_entry + "\n\n")
|
||||
|
||||
def emit(self, record) -> None:
|
||||
if record.name == "httpx":
|
||||
return
|
||||
|
||||
log_entry = self.format(record)
|
||||
self.log += log_entry
|
||||
self.log += "\n\n"
|
||||
self.thread_pool.submit(self._write_log, log_entry)
|
||||
|
||||
def close(self) -> None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
return super().close()
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
@ -21,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
smoothed = []
|
||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
@ -30,7 +31,27 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
return smoothed
|
||||
|
||||
|
||||
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
steps, losses = [], []
|
||||
for log in trainer_log:
|
||||
if log.get("loss", None):
|
||||
steps.append(log["current_steps"])
|
||||
losses.append(log["loss"])
|
||||
|
||||
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
||||
ax.legend()
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
@ -221,16 +221,18 @@ class BAdamArgument:
|
||||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_block_every: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_interval: Optional[int] = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
||||
},
|
||||
)
|
||||
badam_update_ratio: float = field(
|
||||
default=0.0,
|
||||
default=0.05,
|
||||
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
|
||||
)
|
||||
badam_mask_mode: Literal["adjacent", "scatter"] = field(
|
||||
@ -308,6 +310,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
if self.use_galore and self.finetuning_type == "lora":
|
||||
raise ValueError("Cannot use LoRA with GaLore together.")
|
||||
|
||||
if self.use_galore and self.use_badam:
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
|
||||
|
||||
|
@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import TRAINER_CONFIG
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from .data_args import DataArguments
|
||||
@ -251,7 +252,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
files = os.listdir(training_args.output_dir)
|
||||
if last_checkpoint is None and len(files) > 0 and (len(files) != 1 or files[0] != TRAINER_CONFIG):
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
|
@ -1,4 +0,0 @@
|
||||
from .tuner import export_model, run_exp
|
||||
|
||||
|
||||
__all__ = ["export_model", "run_exp"]
|
@ -165,13 +165,13 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
|
||||
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
|
||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
|
||||
|
||||
return losses.mean(), metrics
|
||||
|
@ -113,15 +113,15 @@ class CustomORPOTrainer(DPOTrainer):
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
|
||||
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
|
||||
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
|
||||
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
|
||||
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
|
||||
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().cpu().mean()
|
||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
|
||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
|
||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
|
||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().mean().cpu()
|
||||
metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().mean().cpu()
|
||||
metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().mean().cpu()
|
||||
metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().mean().cpu()
|
||||
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
|
||||
metrics["{}odds_ratio_loss".format(prefix)] = odds_ratio_loss.detach().mean().cpu()
|
||||
|
||||
return batch_loss, metrics
|
||||
|
@ -23,9 +23,9 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
callbacks.append(LogCallback(training_args.output_dir))
|
||||
|
||||
if finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
@ -43,7 +43,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, _ = get_infer_args(args)
|
||||
|
||||
if model_args.export_dir is None:
|
||||
@ -88,7 +88,3 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
||||
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
||||
except Exception:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_exp()
|
||||
|
@ -317,14 +317,14 @@ def _create_badam_optimizer(
|
||||
base_optimizer=base_optimizer,
|
||||
named_parameters_list=list(model.named_parameters()),
|
||||
block_prefix_list=None,
|
||||
switch_block_every=finetuning_args.badam_switch_block_every,
|
||||
switch_block_every=finetuning_args.badam_switch_interval,
|
||||
start_block=finetuning_args.badam_start_block,
|
||||
switch_mode=finetuning_args.badam_switch_mode,
|
||||
verbose=finetuning_args.badam_verbose,
|
||||
)
|
||||
logger.info(
|
||||
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
|
||||
f"switch block every {finetuning_args.badam_switch_block_every} steps, "
|
||||
f"switch block every {finetuning_args.badam_switch_interval} steps, "
|
||||
f"default start block is {finetuning_args.badam_start_block}"
|
||||
)
|
||||
|
||||
|
@ -1,4 +0,0 @@
|
||||
from .interface import create_ui, create_web_demo
|
||||
|
||||
|
||||
__all__ = ["create_ui", "create_web_demo"]
|
@ -4,6 +4,7 @@ from collections import defaultdict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from yaml import safe_dump, safe_load
|
||||
|
||||
from ..extras.constants import (
|
||||
DATA_CONFIG,
|
||||
@ -16,6 +17,7 @@ from ..extras.constants import (
|
||||
TRAINING_STAGES,
|
||||
DownloadSource,
|
||||
)
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import use_modelscope
|
||||
from ..extras.packages import is_gradio_available
|
||||
|
||||
@ -24,12 +26,15 @@ if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
DEFAULT_CONFIG_DIR = "config"
|
||||
DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
USER_CONFIG = "user_config.yaml"
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
@ -47,7 +52,7 @@ def get_save_path(config_path: str) -> os.PathLike:
|
||||
def load_config() -> Dict[str, Any]:
|
||||
try:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
@ -60,13 +65,13 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
||||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(user_config, f, indent=2, ensure_ascii=False)
|
||||
safe_dump(user_config, f)
|
||||
|
||||
|
||||
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
with open(get_save_path(config_path), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@ -74,7 +79,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
|
||||
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
||||
with open(get_save_path(config_path), "w", encoding="utf-8") as f:
|
||||
json.dump(config_dict, f, indent=2, ensure_ascii=False)
|
||||
safe_dump(config_dict, f)
|
||||
|
||||
return str(get_save_path(config_path))
|
||||
|
||||
@ -127,11 +132,15 @@ def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
if dataset_dir == "ONLINE":
|
||||
logger.info("dataset_dir is ONLINE, using online dataset.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as err:
|
||||
print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
|
||||
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
|
||||
return {}
|
||||
|
||||
|
||||
|
@ -36,9 +36,9 @@ def create_chat_box(
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
max_new_tokens = gr.Slider(8, 4096, value=512, step=1)
|
||||
top_p = gr.Slider(0.01, 1.0, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
|
||||
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
|
||||
clear_btn = gr.Button()
|
||||
|
||||
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
|
||||
|
@ -21,25 +21,25 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({dataset_dir, dataset})
|
||||
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
|
||||
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
input_elems.update({cutoff_len, max_samples, batch_size, predict})
|
||||
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
|
||||
|
||||
with gr.Row():
|
||||
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
|
||||
top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
|
||||
@ -52,19 +52,19 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
output_elems = [output_box, process_bar]
|
||||
output_elems = [output_box, progress_bar]
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
resume_btn=resume_btn,
|
||||
process_bar=process_bar,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
)
|
||||
)
|
||||
|
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
|
||||
from ...extras.misc import torch_gc
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ...train import export_model
|
||||
from ...train.tuner import export_model
|
||||
from ..common import get_save_dir
|
||||
from ..locales import ALERTS
|
||||
|
||||
@ -85,7 +85,7 @@ def save_model(
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
export_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
|
||||
export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
|
||||
|
@ -27,7 +27,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
|
||||
)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
@ -52,10 +52,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
|
||||
batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
cutoff_len = gr.Slider(minimum=4, maximum=65536, value=1024, step=1)
|
||||
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
|
||||
val_size = gr.Slider(minimum=0, maximum=1, value=0, step=0.001)
|
||||
lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
|
||||
|
||||
input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
|
||||
@ -71,10 +71,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(open=False) as extra_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||
logging_steps = gr.Slider(minimum=1, maximum=1000, value=5, step=5)
|
||||
save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
|
||||
warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
|
||||
neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
|
||||
optim = gr.Textbox(value="adamw_torch")
|
||||
|
||||
with gr.Row():
|
||||
@ -124,7 +124,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(open=False) as freeze_tab:
|
||||
with gr.Row():
|
||||
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
|
||||
num_layer_trainable = gr.Slider(minimum=1, maximum=128, value=2, step=1)
|
||||
name_module_trainable = gr.Textbox(value="all")
|
||||
|
||||
input_elems.update({num_layer_trainable, name_module_trainable})
|
||||
@ -136,10 +136,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
|
||||
lora_rank = gr.Slider(minimum=1, maximum=1024, value=8, step=1)
|
||||
lora_alpha = gr.Slider(minimum=1, maximum=2048, value=16, step=1)
|
||||
lora_dropout = gr.Slider(minimum=0, maximum=1, value=0, step=0.01)
|
||||
loraplus_lr_ratio = gr.Slider(minimum=0, maximum=64, value=0, step=0.01)
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
@ -180,9 +180,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
|
||||
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
|
||||
dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
|
||||
orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
|
||||
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
|
||||
@ -193,9 +193,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Accordion(open=False) as galore_tab:
|
||||
with gr.Row():
|
||||
use_galore = gr.Checkbox()
|
||||
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
|
||||
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
|
||||
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
|
||||
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
|
||||
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
|
||||
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
|
||||
galore_target = gr.Textbox(value="all")
|
||||
|
||||
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
|
||||
@ -210,6 +210,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as badam_tab:
|
||||
with gr.Row():
|
||||
use_badam = gr.Checkbox()
|
||||
badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
|
||||
badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
|
||||
badam_switch_interval = gr.Slider(minimum=1, maximum=1024, value=50, step=1)
|
||||
badam_update_ratio = gr.Slider(minimum=0, maximum=1, value=0.05, step=0.01)
|
||||
|
||||
input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
badam_tab=badam_tab,
|
||||
use_badam=use_badam,
|
||||
badam_mode=badam_mode,
|
||||
badam_switch_mode=badam_switch_mode,
|
||||
badam_switch_interval=badam_switch_interval,
|
||||
badam_update_ratio=badam_update_ratio,
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
arg_save_btn = gr.Button()
|
||||
@ -225,7 +245,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
@ -243,14 +263,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
resume_btn=resume_btn,
|
||||
process_bar=process_bar,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer,
|
||||
)
|
||||
)
|
||||
|
||||
input_elems.update({output_dir, config_path})
|
||||
output_elems = [output_box, process_bar, loss_viewer]
|
||||
output_elems = [output_box, progress_bar, loss_viewer]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
|
||||
|
@ -41,7 +41,7 @@ class Engine:
|
||||
init_dict["train.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
||||
init_dict["infer.image_box"] = {"visible": False}
|
||||
|
||||
@ -51,7 +51,7 @@ class Engine:
|
||||
|
||||
yield self._update_component(init_dict)
|
||||
|
||||
if self.runner.alive and not self.demo_mode and not self.pure_chat:
|
||||
if self.runner.running and not self.demo_mode and not self.pure_chat:
|
||||
yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._update_component({"train.resume_btn": {"value": True}})
|
||||
|
@ -68,5 +68,9 @@ def create_web_demo() -> gr.Blocks:
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
def run_web_ui() -> None:
|
||||
create_ui().queue().launch()
|
||||
|
||||
|
||||
def run_web_demo() -> None:
|
||||
create_web_demo().queue().launch()
|
||||
|
@ -891,6 +891,87 @@ LOCALES = {
|
||||
"info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
},
|
||||
"badam_tab": {
|
||||
"en": {
|
||||
"label": "BAdam configurations",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Конфигурации BAdam",
|
||||
},
|
||||
"zh": {
|
||||
"label": "BAdam 参数设置",
|
||||
},
|
||||
},
|
||||
"use_badam": {
|
||||
"en": {
|
||||
"label": "Use BAdam",
|
||||
"info": "Enable the BAdam optimizer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Использовать BAdam",
|
||||
"info": "Включите оптимизатор BAdam.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 BAdam",
|
||||
"info": "使用 BAdam 优化器。",
|
||||
},
|
||||
},
|
||||
"badam_mode": {
|
||||
"en": {
|
||||
"label": "BAdam mode",
|
||||
"info": "Whether to use layer-wise or ratio-wise BAdam optimizer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Режим BAdam",
|
||||
"info": "Использовать ли оптимизатор BAdam с послоевой или пропорциональной настройкой.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "BAdam 模式",
|
||||
"info": "使用 layer-wise 或 ratio-wise BAdam 优化器。",
|
||||
},
|
||||
},
|
||||
"badam_switch_mode": {
|
||||
"en": {
|
||||
"label": "Switch mode",
|
||||
"info": "The strategy of picking block to update for layer-wise BAdam.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Режим переключения",
|
||||
"info": "Стратегия выбора блока для обновления для послойного BAdam.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "切换策略",
|
||||
"info": "Layer-wise BAdam 优化器的块切换策略。",
|
||||
},
|
||||
},
|
||||
"badam_switch_interval": {
|
||||
"en": {
|
||||
"label": "Switch interval",
|
||||
"info": "Number of steps to update the block for layer-wise BAdam.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Интервал переключения",
|
||||
"info": "количество шагов для обновления блока для пошагового BAdam.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "切换频率",
|
||||
"info": "Layer-wise BAdam 优化器的块切换频率。",
|
||||
},
|
||||
},
|
||||
"badam_update_ratio": {
|
||||
"en": {
|
||||
"label": "Update ratio",
|
||||
"info": "The ratio of the update for ratio-wise BAdam.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Коэффициент обновления",
|
||||
"info": "Коэффициент обновления для BAdam с учётом соотношений.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "Block 更新比例",
|
||||
"info": "Ratio-wise BAdam 优化器的更新比例。",
|
||||
},
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
"en": {
|
||||
"value": "Preview command",
|
||||
@ -1368,7 +1449,7 @@ ALERTS = {
|
||||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"ru": "Прервано, ожидание завершения...",
|
||||
"zh": "训练中断,正在等待线程结束……",
|
||||
"zh": "训练中断,正在等待进程结束……",
|
||||
},
|
||||
"info_aborted": {
|
||||
"en": "Ready.",
|
||||
|
@ -1,22 +1,19 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator
|
||||
import signal
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
|
||||
import transformers
|
||||
import psutil
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_cuda_available
|
||||
|
||||
from ..extras.callbacks import LogCallback
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
from ..extras.logging import LoggerHandler
|
||||
from ..extras.misc import get_device_count, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from ..train import run_exp
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
|
||||
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
@ -34,24 +31,18 @@ class Runner:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.trainer: Optional["Popen"] = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
""" Handler """
|
||||
self.logger_handler = LoggerHandler()
|
||||
self.logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
return self.thread is not None
|
||||
|
||||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
if self.trainer is not None:
|
||||
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
|
||||
os.kill(children.pid, signal.SIGABRT)
|
||||
|
||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
@ -85,13 +76,11 @@ class Runner:
|
||||
if not from_preview and not is_torch_cuda_available():
|
||||
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
||||
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
return ""
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
|
||||
self.thread = None
|
||||
self.trainer = None
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
self.running_data = None
|
||||
@ -147,12 +136,12 @@ class Runner:
|
||||
shift_attn=get("train.shift_attn"),
|
||||
report_to="all" if get("train.report_to") else "none",
|
||||
use_galore=get("train.use_galore"),
|
||||
use_badam=get("train.use_badam"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
|
||||
fp16=(get("train.compute_type") == "fp16"),
|
||||
bf16=(get("train.compute_type") == "bf16"),
|
||||
pure_bf16=(get("train.compute_type") == "pure_bf16"),
|
||||
)
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if args["finetuning_type"] == "freeze":
|
||||
args["num_layer_trainable"] = get("train.num_layer_trainable")
|
||||
@ -198,6 +187,12 @@ class Runner:
|
||||
args["galore_scale"] = get("train.galore_scale")
|
||||
args["galore_target"] = get("train.galore_target")
|
||||
|
||||
if args["use_badam"]:
|
||||
args["badam_mode"] = get("train.badam_mode")
|
||||
args["badam_switch_mode"] = get("train.badam_switch_mode")
|
||||
args["badam_switch_interval"] = get("train.badam_switch_interval")
|
||||
args["badam_update_ratio"] = get("train.badam_update_ratio")
|
||||
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
@ -237,7 +232,6 @@ class Runner:
|
||||
temperature=get("eval.temperature"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
|
||||
)
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if get("eval.predict"):
|
||||
args["do_predict"] = True
|
||||
@ -263,11 +257,12 @@ class Runner:
|
||||
gr.Warning(error)
|
||||
yield {output_box: error}
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
env = deepcopy(os.environ)
|
||||
env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
env["LLAMABOARD_ENABLED"] = "1"
|
||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data):
|
||||
@ -283,10 +278,10 @@ class Runner:
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self):
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
lang = get("top.lang")
|
||||
model_name = get("top.model_name")
|
||||
finetuning_type = get("top.finetuning_type")
|
||||
@ -294,28 +289,31 @@ class Runner:
|
||||
output_path = get_save_dir(model_name, finetuning_type, output_dir)
|
||||
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
|
||||
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
|
||||
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
|
||||
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
|
||||
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
while self.trainer is not None:
|
||||
if self.aborted:
|
||||
yield {
|
||||
output_box: ALERTS["info_aborting"][lang],
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
else:
|
||||
running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
|
||||
return_dict = {
|
||||
output_box: self.logger_handler.log,
|
||||
process_bar: update_process_bar(self.trainer_callback),
|
||||
output_box: running_log,
|
||||
progress_bar: running_progress,
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
if running_loss is not None:
|
||||
return_dict[loss_viewer] = running_loss
|
||||
|
||||
yield return_dict
|
||||
|
||||
time.sleep(2)
|
||||
try:
|
||||
self.trainer.wait(2)
|
||||
self.trainer = None
|
||||
except TimeoutExpired:
|
||||
continue
|
||||
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
|
||||
@ -330,16 +328,11 @@ class Runner:
|
||||
|
||||
return_dict = {
|
||||
output_box: self._finalize(lang, finish_info),
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data):
|
||||
def save_args(self, data: dict):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
|
@ -1,10 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from yaml import safe_dump
|
||||
|
||||
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
|
||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import smooth
|
||||
from ..extras.ploting import gen_loss_plot
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
@ -12,30 +15,6 @@ if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..extras.callbacks import LogCallback
|
||||
|
||||
|
||||
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
|
||||
if not callback.max_steps:
|
||||
return gr.Slider(visible=False)
|
||||
|
||||
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
|
||||
)
|
||||
return gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
|
||||
if finetuning_type != "lora":
|
||||
return gr.Dropdown(value="none", interactive=False)
|
||||
@ -57,14 +36,18 @@ def check_json_schema(text: str, lang: str) -> None:
|
||||
gr.Warning(ALERTS["err_json_schema"][lang])
|
||||
|
||||
|
||||
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
no_skip_keys = ["packing"]
|
||||
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
args.pop("disable_tqdm", None)
|
||||
args["plot_loss"] = args.get("do_train", None)
|
||||
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
|
||||
for k, v in args.items():
|
||||
if v is not None and v is not False and v != "":
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES={} llamafactory-cli train ".format(current_devices)]
|
||||
for k, v in clean_cmd(args).items():
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
|
||||
cmd_text = "\\\n".join(cmd_lines)
|
||||
cmd_text = "```bash\n{}\n```".format(cmd_text)
|
||||
return cmd_text
|
||||
@ -76,29 +59,49 @@ def get_eval_results(path: os.PathLike) -> str:
|
||||
return "```json\n{}\n```\n".format(result)
|
||||
|
||||
|
||||
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
|
||||
log_file = os.path.join(output_path, "trainer_log.jsonl")
|
||||
if not os.path.isfile(log_file) or not is_matplotlib_available():
|
||||
return
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
steps, losses = [], []
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
|
||||
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
|
||||
running_log = ""
|
||||
running_progress = gr.Slider(visible=False)
|
||||
running_loss = None
|
||||
|
||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||
if os.path.isfile(running_log_path):
|
||||
with open(running_log_path, "r", encoding="utf-8") as f:
|
||||
running_log = f.read()
|
||||
|
||||
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
|
||||
if os.path.isfile(trainer_log_path):
|
||||
trainer_log: List[Dict[str, Any]] = []
|
||||
with open(trainer_log_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
log_info: Dict[str, Any] = json.loads(line)
|
||||
if log_info.get("loss", None):
|
||||
steps.append(log_info["current_steps"])
|
||||
losses.append(log_info["loss"])
|
||||
trainer_log.append(json.loads(line))
|
||||
|
||||
if len(losses) == 0:
|
||||
return
|
||||
if len(trainer_log) != 0:
|
||||
latest_log = trainer_log[-1]
|
||||
percentage = latest_log["percentage"]
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
latest_log["current_steps"],
|
||||
latest_log["total_steps"],
|
||||
latest_log["elapsed_time"],
|
||||
latest_log["remaining_time"],
|
||||
)
|
||||
running_progress = gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
||||
ax.legend()
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
if do_train and is_matplotlib_available():
|
||||
running_loss = gr.Plot(gen_loss_plot(trainer_log))
|
||||
|
||||
return running_log, running_progress, running_loss
|
||||
|
||||
|
||||
def save_cmd(args: Dict[str, Any]) -> str:
|
||||
output_dir = args["output_dir"]
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
|
||||
safe_dump(clean_cmd(args), f)
|
||||
|
||||
return os.path.join(output_dir, TRAINER_CONFIG)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from llmtuner import run_exp
|
||||
from llmtuner.train.tuner import run_exp
|
||||
|
||||
|
||||
def main():
|
||||
@ -7,7 +7,7 @@ def main():
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
run_exp()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
@ -1,9 +0,0 @@
|
||||
from llmtuner import create_ui
|
||||
|
||||
|
||||
def main():
|
||||
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,9 +0,0 @@
|
||||
from llmtuner import create_web_demo
|
||||
|
||||
|
||||
def main():
|
||||
create_web_demo().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user