diff --git a/.dockerignore b/.dockerignore
index 2ac0e11d..23ad75a8 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -4,10 +4,10 @@
.venv
cache
data
+docker
+saves
hf_cache
output
-examples
.dockerignore
.gitattributes
.gitignore
-Dockerfile
diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 1d962200..768adea6 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -38,7 +38,9 @@ body:
请合理使用 Markdown 标签来格式化您的文本。
placeholder: |
+ ```bash
llamafactory-cli train ...
+ ```
- type: textarea
id: expected-behavior
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index b31e9d19..d23d6be3 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -5,3 +5,4 @@ Fixes # (issue)
## Before submitting
- [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
+- [ ] Did you write any new necessary tests?
diff --git a/.github/workflows/label_issue.yml b/.github/workflows/label_issue.yml
new file mode 100644
index 00000000..ffd644a7
--- /dev/null
+++ b/.github/workflows/label_issue.yml
@@ -0,0 +1,27 @@
+name: label_issue
+
+on:
+ issues:
+ types:
+ - opened
+
+jobs:
+ label_issue:
+ runs-on: ubuntu-latest
+
+ steps:
+ - env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ ISSUE_URL: ${{ github.event.issue.html_url }}
+ ISSUE_TITLE: ${{ github.event.issue.title }}
+ run: |
+ LABEL=pending
+ NPU_KEYWORDS=(npu ascend huawei 华为 昇腾)
+ ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
+ for KEYWORD in ${NPU_KEYWORDS[@]}; do
+ if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
+ LABEL=pending,npu
+ break
+ fi
+ done
+ gh issue edit $ISSUE_URL --add-label $LABEL
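The labeling script above is plain bash, so its keyword matching can be dry-run locally without the `gh` CLI. A minimal sketch of that check (the issue title is a made-up example, and `$ISSUE_TITLE` is quoted here to stay safe with arbitrary titles):

```bash
#!/usr/bin/env bash
# Dry-run of the keyword matching in label_issue.yml (no `gh issue edit` call).
ISSUE_TITLE="Cannot launch LoRA training on Ascend 910B"  # hypothetical title
LABEL=pending
NPU_KEYWORDS=(npu ascend huawei 华为 昇腾)
ISSUE_TITLE_LOWER=$(echo "$ISSUE_TITLE" | tr '[:upper:]' '[:lower:]')
for KEYWORD in "${NPU_KEYWORDS[@]}"; do
  if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
    LABEL=pending,npu
    break
  fi
done
echo "$LABEL"  # prints "pending,npu" for this title
```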
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 00000000..15c7153e
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,40 @@
+name: publish
+
+on:
+ release:
+ types:
+ - published
+
+jobs:
+ publish:
+ name: Upload release to PyPI
+
+ runs-on: ubuntu-latest
+
+ environment:
+ name: release
+ url: https://pypi.org/p/llamafactory
+
+ permissions:
+ id-token: write
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.8"
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install build
+
+ - name: Build package
+ run: |
+ python -m build
+
+ - name: Publish package
+ uses: pypa/gh-action-pypi-publish@release/v1
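The publish job only builds the sdist and wheel with `python -m build` and hands the artifacts to the PyPI action. The same build can be reproduced locally to sanity-check a release before tagging it; a small sketch (the `twine check` step is an extra suggestion, not part of this workflow):

```bash
python -m pip install --upgrade pip build twine
python -m build               # writes dist/*.tar.gz and dist/*.whl
python -m twine check dist/*  # validates package metadata and long description
```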
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 32edf6a8..73d77de5 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -19,21 +19,27 @@ on:
jobs:
tests:
runs-on: ubuntu-latest
+
steps:
- - uses: actions/checkout@v4
+ - name: Checkout
+ uses: actions/checkout@v4
+
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.8"
cache: "pip"
cache-dependency-path: "setup.py"
+
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- python -m pip install .[torch,dev]
+ python -m pip install ".[torch,dev]"
+
- name: Check quality
run: |
make style && make quality
+
- name: Test with pytest
run: |
make test
diff --git a/.gitignore b/.gitignore
index 0355c666..82e6e9e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,6 +160,8 @@ cython_debug/
.idea/
# custom .gitignore
-user.config
-saves/
cache/
+config/
+saves/
+output/
+wandb/
diff --git a/CITATION.cff b/CITATION.cff
index 4caf3787..01b4c9fd 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -12,12 +12,16 @@ authors:
given-names: "Yanhan"
- family-names: "Luo"
given-names: "Zheyan"
+- family-names: "Feng"
+ given-names: "Zhangchi"
- family-names: "Ma"
given-names: "Yongqiang"
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
url: "https://arxiv.org/abs/2403.13372"
preferred-citation:
- type: article
+ type: conference-paper
+ conference:
+ name: "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)"
authors:
- family-names: "Zheng"
given-names: "Yaowei"
@@ -29,9 +33,12 @@ preferred-citation:
given-names: "Yanhan"
- family-names: "Luo"
given-names: "Zheyan"
+ - family-names: "Feng"
+ given-names: "Zhangchi"
- family-names: "Ma"
given-names: "Yongqiang"
- journal: "arXiv preprint arXiv:2403.13372"
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
url: "https://arxiv.org/abs/2403.13372"
year: 2024
+ publisher: "Association for Computational Linguistics"
+ address: "Bangkok, Thailand"
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 0a35e355..00000000
--- a/Dockerfile
+++ /dev/null
@@ -1,14 +0,0 @@
-FROM nvcr.io/nvidia/pytorch:24.01-py3
-
-WORKDIR /app
-
-COPY requirements.txt /app/
-RUN pip install -r requirements.txt
-
-COPY . /app/
-RUN pip install -e .[metrics,bitsandbytes,qwen]
-
-VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
-EXPOSE 7860
-
-CMD [ "llamafactory-cli", "webui" ]
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..82c51f63
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE requirements.txt
diff --git a/Makefile b/Makefile
index 65be047b..3f13b215 100644
--- a/Makefile
+++ b/Makefile
@@ -11,4 +11,4 @@ style:
ruff format $(check_dirs)
test:
- pytest tests/
+ CUDA_VISIBLE_DEVICES= pytest tests/
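Prefixing pytest with an empty `CUDA_VISIBLE_DEVICES=` hides every GPU from the process, so `make test` now runs the suite on CPU regardless of the host hardware. A quick way to confirm the effect, assuming PyTorch is installed:

```bash
CUDA_VISIBLE_DEVICES= python -c "import torch; print(torch.cuda.is_available())"  # prints False
```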
diff --git a/README.md b/README.md
index fb6c5782..3d3feae5 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[](LICENSE)
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[](https://pypi.org/project/llamafactory/)
-[](#projects-using-llama-factory)
+[](#projects-using-llama-factory)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://discord.gg/rKfvV9r9FK)
[](https://twitter.com/llamafactory_ai)
@@ -15,7 +15,7 @@
[](https://trendshift.io/repositories/4535)
-👋 Join our [WeChat](assets/wechat.jpg).
+👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).
\[ English | [中文](README_zh.md) \]
@@ -48,8 +48,8 @@ Choose your path:
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
-- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
+- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -71,9 +71,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
-[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models.
+[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
-[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models.
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
@@ -151,35 +151,32 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Supported Models
-| Model | Model size | Template |
-| -------------------------------------------------------- | -------------------------------- | --------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
-| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| Model | Model size | Template |
+| ------------------------------------------------------------ | -------------------------------- | --------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -259,6 +256,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@@ -335,10 +335,10 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
-pip install -e '.[torch,metrics]'
+pip install -e ".[torch,metrics]"
```
-Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -357,9 +357,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
For Ascend NPU users
-Join [NPU user group](assets/wechat_npu.jpg).
-
-To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
```bash
# replace the url according to your CANN version and devices
@@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |
-Docker image:
-
-- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
-- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
If you cannot run inference on NPU devices, try setting `do_sample: false` in the configurations.
+Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
### Data Preparation
@@ -405,9 +400,9 @@ Please refer to [data/README.md](data/README.md) for checking the details about
Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
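The quickstart no longer hard-codes `CUDA_VISIBLE_DEVICES=0`; the variable can still be exported to pin a run to a specific device if desired, for example:

```bash
# Optional: restrict the run to GPU 0 (no longer required by the quickstart)
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```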
@@ -417,34 +412,89 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))
-#### Use local environment
-
```bash
-CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
+llamafactory-cli webui
```
-
+### Build Docker
-#### Use Docker
+For CUDA users:
```bash
-docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all \
- -v ./hf_cache:/root/.cache/huggingface/ \
+cd docker/docker-cuda/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+For Ascend NPU users:
+
+```bash
+cd docker/docker-npu/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+Build without Docker Compose
+
+For CUDA users:
+
+```bash
+docker build -f ./docker/docker-cuda/Dockerfile \
+ --build-arg INSTALL_BNB=false \
+ --build-arg INSTALL_VLLM=false \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg INSTALL_FLASHATTN=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+docker run -dit --gpus=all \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-p 7860:7860 \
+ -p 8000:8000 \
--shm-size 16G \
- --name llama_factory \
- -d llama-factory:latest
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
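After the CUDA container is up, it is worth confirming that the GPUs were actually passed through before launching a job; a minimal check, assuming the NVIDIA drivers and container toolkit are installed on the host:

```bash
docker exec -it llamafactory nvidia-smi  # should list the host GPUs
```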
-#### Use Docker Compose
+For Ascend NPU users:
```bash
-docker compose -f ./docker-compose.yml up -d
+# Choose the Docker image according to your environment
+docker build -f ./docker/docker-npu/Dockerfile \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+# Change `device` according to your resources
+docker run -dit \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
+ -v ./data:/app/data \
+ -v ./output:/app/output \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --device /dev/davinci0 \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ --shm-size 16G \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
+
+
Details about volume
- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
@@ -456,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d
### Deploy with OpenAI-style API and vLLM
```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
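Once the server is listening on port 8000, it speaks the OpenAI-style chat completions protocol, so any OpenAI-compatible client can talk to it. A minimal smoke test with `curl` (the model name below is a placeholder):

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
```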
> [!TIP]
@@ -474,7 +524,7 @@ Train the model by specifying a model ID of the ModelScope Hub as the `model_nam
### Use W&B Logger
-To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
+To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to the yaml files.
```yaml
report_to: wandb
@@ -494,38 +544,63 @@ If you have a project that should be incorporated, please contact via email or c
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
-1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
-1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
-1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
-1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
-1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
-1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
-1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
-1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
+1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
@@ -533,6 +608,8 @@ If you have a project that should be incorporated, please contact via email or c
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
@@ -540,17 +617,19 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation
If this work is helpful, please kindly cite as:
```bibtex
-@article{zheng2024llamafactory,
+@inproceedings{zheng2024llamafactory,
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
- author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
- journal={arXiv preprint arXiv:2403.13372},
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+ address={Bangkok, Thailand},
+ publisher={Association for Computational Linguistics},
year={2024},
url={http://arxiv.org/abs/2403.13372}
}
diff --git a/README_zh.md b/README_zh.md
index 142254df..cb5a42e4 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -4,7 +4,7 @@
[](LICENSE)
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[](https://pypi.org/project/llamafactory/)
-[](#使用了-llama-factory-的项目)
+[](#使用了-llama-factory-的项目)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://discord.gg/rKfvV9r9FK)
[](https://twitter.com/llamafactory_ai)
@@ -15,7 +15,7 @@
[](https://trendshift.io/repositories/4535)
-👋 加入我们的[微信群](assets/wechat.jpg)。
+👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。
\[ [English](README.md) | 中文 \]
@@ -48,8 +48,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
-- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
-- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
+- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
+- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -71,9 +71,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
-[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。
+[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
-[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。
+[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
@@ -151,35 +151,32 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 模型
-| 模型名 | 模型大小 | Template |
-| -------------------------------------------------------- | -------------------------------- | --------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
-| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| 模型名 | 模型大小 | Template |
+| ------------------------------------------------------------ | -------------------------------- | --------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
> [!NOTE]
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
@@ -259,6 +256,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@@ -335,10 +335,10 @@ huggingface-cli login
```bash
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
-pip install -e '.[torch,metrics]'
+pip install -e ".[torch,metrics]"
```
-可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
+可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality
> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -357,9 +357,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
昇腾 NPU 用户指南
-加入 [NPU 用户群](assets/wechat_npu.jpg)。
-
-在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
+在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
@@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |
-Docker 镜像:
-
-- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
-- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
-
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
+下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
+
### 数据准备
@@ -405,9 +400,9 @@ Docker 镜像:
下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。
@@ -417,32 +412,89 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)
-#### 使用本地环境
-
```bash
-CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
+llamafactory-cli webui
```
-#### 使用 Docker
+### 构建 Docker
+
+CUDA 用户:
```bash
-docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all \
- -v ./hf_cache:/root/.cache/huggingface/ \
+cd docker/docker-cuda/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+昇腾 NPU 用户:
+
+```bash
+cd docker/docker-npu/
+docker-compose up -d
+docker-compose exec llamafactory bash
+```
+
+不使用 Docker Compose 构建
+
+CUDA 用户:
+
+```bash
+docker build -f ./docker/docker-cuda/Dockerfile \
+ --build-arg INSTALL_BNB=false \
+ --build-arg INSTALL_VLLM=false \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg INSTALL_FLASHATTN=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+docker run -dit --gpus=all \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-p 7860:7860 \
+ -p 8000:8000 \
--shm-size 16G \
- --name llama_factory \
- -d llama-factory:latest
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
-#### 使用 Docker Compose
+昇腾 NPU 用户:
```bash
-docker compose -f ./docker-compose.yml up -d
+# 根据您的环境选择镜像
+docker build -f ./docker/docker-npu/Dockerfile \
+ --build-arg INSTALL_DEEPSPEED=false \
+ --build-arg PIP_INDEX=https://pypi.org/simple \
+ -t llamafactory:latest .
+
+# 根据您的资源更改 `device`
+docker run -dit \
+ -v ./hf_cache:/root/.cache/huggingface \
+ -v ./ms_cache:/root/.cache/modelscope \
+ -v ./data:/app/data \
+ -v ./output:/app/output \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -p 7860:7860 \
+ -p 8000:8000 \
+ --device /dev/davinci0 \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ --shm-size 16G \
+ --name llamafactory \
+ llamafactory:latest
+
+docker exec -it llamafactory bash
```
+
+
数据卷详情
- hf_cache:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
@@ -454,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d
### 利用 vLLM 部署 OpenAI API
```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
```
> [!TIP]
@@ -472,7 +524,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
### 使用 W&B 面板
-若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。
+若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
```yaml
report_to: wandb
@@ -492,38 +544,63 @@ run_name: test_run # 可选
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
-1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
-1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collapse. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
-1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
-1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
-1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
+1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
-1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
+1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
-1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
-1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
+1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
+1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
+1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
+1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
+1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
+1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
+1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
+1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
+1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
+1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
+1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
+1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
+1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
+1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
+1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
+1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
+1. Zhu et al. Are Large Language Models Good Statisticians? 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
+1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
+1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
+1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
+1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
+1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
+1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
+1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
+1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
+1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sunsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
@@ -531,6 +608,8 @@ run_name: test_run # 可选
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
+1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
+1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: 在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
@@ -538,17 +617,19 @@ run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
-使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用
如果您觉得此项目有帮助,请考虑以下列格式引用
```bibtex
-@article{zheng2024llamafactory,
- title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
- author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
- journal={arXiv preprint arXiv:2403.13372},
+@inproceedings{zheng2024llamafactory,
+ title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
+ author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
+ booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+ address={Bangkok, Thailand},
+ publisher={Association for Computational Linguistics},
year={2024},
url={http://arxiv.org/abs/2403.13372}
}
diff --git a/docker-compose.yml b/docker-compose.yml
deleted file mode 100644
index 9602a3e3..00000000
--- a/docker-compose.yml
+++ /dev/null
@@ -1,23 +0,0 @@
-version: '3.8'
-
-services:
- llama-factory:
- build:
- dockerfile: Dockerfile
- context: .
- container_name: llama_factory
- volumes:
- - ./hf_cache:/root/.cache/huggingface/
- - ./data:/app/data
- - ./output:/app/output
- ports:
- - "7860:7860"
- ipc: host
- deploy:
- resources:
- reservations:
- devices:
- - driver: nvidia
- count: "all"
- capabilities: [gpu]
- restart: unless-stopped
diff --git a/docker/docker-cuda/Dockerfile b/docker/docker-cuda/Dockerfile
new file mode 100644
index 00000000..d94aa970
--- /dev/null
+++ b/docker/docker-cuda/Dockerfile
@@ -0,0 +1,58 @@
+# Use the NVIDIA official image with PyTorch 2.3.0
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
+FROM nvcr.io/nvidia/pytorch:24.02-py3
+
+# Define environments
+ENV MAX_JOBS=4
+ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
+
+# Define installation arguments
+ARG INSTALL_BNB=false
+ARG INSTALL_VLLM=false
+ARG INSTALL_DEEPSPEED=false
+ARG INSTALL_FLASHATTN=false
+ARG PIP_INDEX=https://pypi.org/simple
+
+# Set the working directory
+WORKDIR /app
+
+# Install the requirements
+COPY requirements.txt /app
+RUN pip config set global.index-url "$PIP_INDEX" && \
+ pip config set global.extra-index-url "$PIP_INDEX" && \
+ python -m pip install --upgrade pip && \
+ python -m pip install -r requirements.txt
+
+# Rebuild flash attention
+RUN pip uninstall -y transformer-engine flash-attn && \
+ if [ "$INSTALL_FLASHATTN" == "true" ]; then \
+ pip uninstall -y ninja && pip install ninja && \
+ pip install --no-cache-dir flash-attn --no-build-isolation; \
+ fi
+
+# Copy the rest of the application into the image
+COPY . /app
+
+# Install LLaMA Factory
+RUN EXTRA_PACKAGES="metrics"; \
+ if [ "$INSTALL_BNB" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
+ fi; \
+ if [ "$INSTALL_VLLM" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
+ fi; \
+ if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
+ fi; \
+ pip install -e ".[$EXTRA_PACKAGES]"
+
+# Set up volumes
+VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
+
+# Expose port 7860 for the LLaMA Board
+ENV GRADIO_SERVER_PORT 7860
+EXPOSE 7860
+
+# Expose port 8000 for the API service
+ENV API_PORT 8000
+EXPOSE 8000
diff --git a/docker/docker-cuda/docker-compose.yml b/docker/docker-cuda/docker-compose.yml
new file mode 100644
index 00000000..16267dc3
--- /dev/null
+++ b/docker/docker-cuda/docker-compose.yml
@@ -0,0 +1,32 @@
+services:
+ llamafactory:
+ build:
+ dockerfile: ./docker/docker-cuda/Dockerfile
+ context: ../..
+ args:
+ INSTALL_BNB: false
+ INSTALL_VLLM: false
+ INSTALL_DEEPSPEED: false
+ INSTALL_FLASHATTN: false
+ PIP_INDEX: https://pypi.org/simple
+ container_name: llamafactory
+ volumes:
+ - ../../hf_cache:/root/.cache/huggingface
+ - ../../ms_cache:/root/.cache/modelscope
+ - ../../data:/app/data
+ - ../../output:/app/output
+ ports:
+ - "7860:7860"
+ - "8000:8000"
+ ipc: host
+ tty: true
+ stdin_open: true
+ command: bash
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: "all"
+ capabilities: [gpu]
+ restart: unless-stopped
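A typical workflow with this compose file is to build and start the container, then attach to it, since `command: bash` together with `tty` and `stdin_open` keeps it idle. The commands below are a hedged sketch based on the compose file above (the container name `llamafactory` comes from `container_name`), not part of this diff:

```bash
cd docker/docker-cuda
docker compose up -d              # build the image and start the container in the background
docker exec -it llamafactory bash # open a shell inside the running container
```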
diff --git a/docker/docker-npu/Dockerfile b/docker/docker-npu/Dockerfile
new file mode 100644
index 00000000..34cf9616
--- /dev/null
+++ b/docker/docker-npu/Dockerfile
@@ -0,0 +1,45 @@
+# Use the Ubuntu 22.04 image with CANN 8.0.rc1
+# More versions can be found at https://hub.docker.com/r/cosdt/cann/tags
+# FROM cosdt/cann:8.0.rc1-910-ubuntu22.04
+FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04
+# FROM cosdt/cann:8.0.rc1-910-openeuler22.03
+# FROM cosdt/cann:8.0.rc1-910b-openeuler22.03
+
+# Define environments
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Define installation arguments
+ARG INSTALL_DEEPSPEED=false
+ARG PIP_INDEX=https://pypi.org/simple
+ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
+
+# Set the working directory
+WORKDIR /app
+
+# Install the requirements
+COPY requirements.txt /app
+RUN pip config set global.index-url "$PIP_INDEX" && \
+ pip config set global.extra-index-url "$TORCH_INDEX" && \
+ python -m pip install --upgrade pip && \
+ python -m pip install -r requirements.txt
+
+# Copy the rest of the application into the image
+COPY . /app
+
+# Install LLaMA Factory
+RUN EXTRA_PACKAGES="torch-npu,metrics"; \
+ if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
+ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
+ fi; \
+ pip install -e ".[$EXTRA_PACKAGES]"
+
+# Set up volumes
+VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
+
+# Expose port 7860 for the LLaMA Board
+ENV GRADIO_SERVER_PORT 7860
+EXPOSE 7860
+
+# Expose port 8000 for the API service
+ENV API_PORT 8000
+EXPOSE 8000
diff --git a/docker/docker-npu/docker-compose.yml b/docker/docker-npu/docker-compose.yml
new file mode 100644
index 00000000..657cba9f
--- /dev/null
+++ b/docker/docker-npu/docker-compose.yml
@@ -0,0 +1,31 @@
+services:
+ llamafactory:
+ build:
+ dockerfile: ./docker/docker-npu/Dockerfile
+ context: ../..
+ args:
+ INSTALL_DEEPSPEED: false
+ PIP_INDEX: https://pypi.org/simple
+ container_name: llamafactory
+ volumes:
+ - ../../hf_cache:/root/.cache/huggingface
+ - ../../ms_cache:/root/.cache/modelscope
+ - ../../data:/app/data
+ - ../../output:/app/output
+ - /usr/local/dcmi:/usr/local/dcmi
+ - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
+ - /usr/local/Ascend/driver:/usr/local/Ascend/driver
+ - /etc/ascend_install.info:/etc/ascend_install.info
+ ports:
+ - "7860:7860"
+ - "8000:8000"
+ ipc: host
+ tty: true
+ stdin_open: true
+ command: bash
+ devices:
+ - /dev/davinci0
+ - /dev/davinci_manager
+ - /dev/devmm_svm
+ - /dev/hisi_hdc
+ restart: unless-stopped
diff --git a/evaluation/ceval/ceval.py b/evaluation/ceval/ceval.py
index 4111d6b4..48442d50 100644
--- a/evaluation/ceval/ceval.py
+++ b/evaluation/ceval/ceval.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import datasets
diff --git a/evaluation/cmmlu/cmmlu.py b/evaluation/cmmlu/cmmlu.py
index 37efb328..5ff548a4 100644
--- a/evaluation/cmmlu/cmmlu.py
+++ b/evaluation/cmmlu/cmmlu.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import datasets
diff --git a/evaluation/mmlu/mmlu.py b/evaluation/mmlu/mmlu.py
index a4530250..1065fb31 100644
--- a/evaluation/mmlu/mmlu.py
+++ b/evaluation/mmlu/mmlu.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import datasets
diff --git a/examples/README.md b/examples/README.md
index f985d552..d5aca5ad 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
## Table of Contents
-- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu)
-- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu)
-- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus)
-- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
-- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
+- [LoRA Fine-Tuning](#lora-fine-tuning)
+- [QLoRA Fine-Tuning](#qlora-fine-tuning)
+- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
- [Extras](#extras)
+Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
+
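+For example, to restrict a run to the first two GPUs (the config path below is one of the recipes in this repository):
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+```
+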
## Examples
-### LoRA Fine-Tuning on A Single GPU
+### LoRA Fine-Tuning
#### (Continuous) Pre-Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```
#### Supervised Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
```
#### Reward Modeling
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```
#### PPO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### DPO/ORPO/SimPO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
```
#### Preprocess Dataset
@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
It is useful for large datasets; use `tokenized_path` in the config to load the preprocessed dataset.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
+llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```
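+A minimal sketch of how `tokenized_path` might be set in the config (the save path is illustrative; `tokenized_path` itself is the documented option):
+
+```yaml
+dataset: identity,alpaca_en_demo
+template: llama3
+tokenized_path: saves/llama3-8b/dataset/sft  # illustrative path; reused on the next run
+```
+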
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
+llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
-```
-
-### QLoRA Fine-Tuning on a Single GPU
-
-#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
-```
-
-#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
-```
-
-#### Supervised Fine-Tuning with 4-bit AWQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
-```
-
-#### Supervised Fine-Tuning with 2-bit AQLM Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
-```
-
-### LoRA Fine-Tuning on Multiple GPUs
-
-#### Supervised Fine-Tuning on Single Node
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
-### LoRA Fine-Tuning on Multiple NPUs
+### QLoRA Fine-Tuning
-#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
+#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
```bash
-ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
```
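+The on-the-fly quantization backend is selected inside the YAML file; the relevant fields from the new `llama3_lora_sft_otfq.yaml` in this change are:
+
+```yaml
+quantization_bit: 4
+quantization_method: bitsandbytes  # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
+```
+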
-### Full-Parameter Fine-Tuning on Multiple GPUs
+#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
+```
+
+#### Supervised Fine-Tuning with 4-bit AWQ Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
+```
+
+#### Supervised Fine-Tuning with 2-bit AQLM Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
+```
+
+### Full-Parameter Fine-Tuning
#### Supervised Fine-Tuning on Single Node
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
+llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Quantization
@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
Note: DO NOT use a quantized model or the `quantization_bit` argument when merging LoRA adapters.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
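+For reference, a rough sketch of what such a merge config might contain, assuming the usual export options (the adapter and output paths are placeholders):
+
+```yaml
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+adapter_name_or_path: saves/llama3-8b/lora/sft  # placeholder adapter path
+template: llama3
+finetuning_type: lora
+export_dir: models/llama3_lora_sft              # placeholder output path
+```
+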
#### Quantizing Model using AutoGPTQ
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### Inferring LoRA Fine-Tuned Models
-Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
-
#### Use CLI
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### Use Web UI
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
```
#### Launch OpenAI-style API
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/llama3_lora_sft.yaml
```
### Extras
@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
#### Full-Parameter Fine-Tuning using GaLore
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using BAdam
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### LoRA+ Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+```
+
+#### PiSSA Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```
#### Mixture-of-Depths Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
```
#### LLaMA-Pro Fine-Tuning
```bash
bash examples/extras/llama_pro/expand.sh
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
+llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```
#### FSDP+QLoRA Fine-Tuning
```bash
-bash examples/extras/fsdp_qlora/single_node.sh
+bash examples/extras/fsdp_qlora/train.sh
```
diff --git a/examples/README_zh.md b/examples/README_zh.md
index cf5bbf49..d96bf882 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -4,59 +4,59 @@
## 目录
-- [单 GPU LoRA 微调](#单-gpu-lora-微调)
-- [单 GPU QLoRA 微调](#单-gpu-qlora-微调)
-- [多 GPU LoRA 微调](#多-gpu-lora-微调)
-- [多 NPU LoRA 微调](#多-npu-lora-微调)
-- [多 GPU 全参数微调](#多-gpu-全参数微调)
+- [LoRA 微调](#lora-微调)
+- [QLoRA 微调](#qlora-微调)
+- [全参数微调](#全参数微调)
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
- [推理 LoRA 模型](#推理-lora-模型)
- [杂项](#杂项)
+使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。
+
## 示例
-### 单 GPU LoRA 微调
+### LoRA 微调
#### (增量)预训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```
#### 指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### 多模态指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
```
#### 奖励模型训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```
#### PPO 训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### DPO/ORPO/SimPO 训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO 训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
```
#### 预处理数据集
@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
+llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```
#### 在 MMLU/CMMLU/C-Eval 上评估
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
+llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
-### 单 GPU QLoRA 微调
-
-#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
+#### 多机指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
-```
-
-#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
-```
-
-#### 基于 4 比特 AWQ 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
-```
-
-#### 基于 2 比特 AQLM 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
-```
-
-### 多 GPU LoRA 微调
-
-#### 在单机上进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-```
-
-#### 在多机上进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### 使用 DeepSpeed ZeRO-3 平均分配显存
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
-### 多 NPU LoRA 微调
+### QLoRA 微调
-#### 使用 DeepSpeed ZeRO-0 进行指令监督微调
+#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
```bash
-ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
```
-### 多 GPU 全参数微调
+#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
+```
+
+#### 基于 4 比特 AWQ 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
+```
+
+#### 基于 2 比特 AQLM 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
+```
+
+### 全参数微调
#### 在单机上进行指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 在多机上进行指令监督微调
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 批量预测并计算 BLEU 和 ROUGE 分数
```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
+llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### 合并 LoRA 适配器与模型量化
@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
#### 使用 AutoGPTQ 量化模型
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### 推理 LoRA 模型
-使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
-
#### 使用命令行接口
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### 使用浏览器界面
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
```
#### 启动 OpenAI 风格 API
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/llama3_lora_sft.yaml
```
### 杂项
@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
#### 使用 GaLore 进行全参数训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```
#### 使用 BAdam 进行全参数训练
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### LoRA+ 微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+```
+
+#### PiSSA 微调
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```
#### 深度混合微调
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
```
#### LLaMA-Pro 微调
```bash
bash examples/extras/llama_pro/expand.sh
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
+llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```
#### FSDP+QLoRA 微调
```bash
-bash examples/extras/fsdp_qlora/single_node.sh
+bash examples/extras/fsdp_qlora/train.sh
```
diff --git a/examples/full_multi_gpu/llama3_full_sft.yaml b/examples/extras/badam/llama3_full_sft.yaml
similarity index 81%
rename from examples/full_multi_gpu/llama3_full_sft.yaml
rename to examples/extras/badam/llama3_full_sft.yaml
index 40b62f24..31d61c33 100644
--- a/examples/full_multi_gpu/llama3_full_sft.yaml
+++ b/examples/extras/badam/llama3_full_sft.yaml
@@ -5,10 +5,11 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft
do_train: true
finetuning_type: full
-
-### ddp
-ddp_timeout: 180000000
-deepspeed: examples/deepspeed/ds_z3_config.json
+use_badam: true
+badam_mode: layer
+badam_switch_mode: ascending
+badam_switch_interval: 50
+badam_verbose: 2
### dataset
dataset: identity,alpaca_en_demo
@@ -27,12 +28,11 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
### eval
val_size: 0.1
diff --git a/examples/extras/badam/llama3_lora_sft.yaml b/examples/extras/badam/llama3_full_sft_ds3.yaml
similarity index 91%
rename from examples/extras/badam/llama3_lora_sft.yaml
rename to examples/extras/badam/llama3_full_sft_ds3.yaml
index a78de2fa..f2d7309f 100644
--- a/examples/extras/badam/llama3_lora_sft.yaml
+++ b/examples/extras/badam/llama3_full_sft_ds3.yaml
@@ -6,9 +6,11 @@ stage: sft
do_train: true
finetuning_type: full
use_badam: true
+badam_mode: layer
badam_switch_mode: ascending
badam_switch_interval: 50
badam_verbose: 2
+deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
@@ -32,7 +34,6 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-pure_bf16: true
### eval
val_size: 0.1
diff --git a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
index 084269ef..6c80ef58 100644
--- a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
+++ b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
@@ -8,9 +8,6 @@ do_train: true
finetuning_type: lora
lora_target: all
-### ddp
-ddp_timeout: 180000000
-
### dataset
dataset: identity,alpaca_en_demo
template: llama3
@@ -33,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/extras/fsdp_qlora/single_node.sh b/examples/extras/fsdp_qlora/train.sh
similarity index 100%
rename from examples/extras/fsdp_qlora/single_node.sh
rename to examples/extras/fsdp_qlora/train.sh
diff --git a/examples/extras/llama_pro/llama3_freeze_sft.yaml b/examples/extras/llama_pro/llama3_freeze_sft.yaml
index 444a1113..5e7e90bb 100644
--- a/examples/extras/llama_pro/llama3_freeze_sft.yaml
+++ b/examples/extras/llama_pro/llama3_freeze_sft.yaml
@@ -31,7 +31,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/extras/loraplus/llama3_lora_sft.yaml b/examples/extras/loraplus/llama3_lora_sft.yaml
index 1ba654ec..062a312b 100644
--- a/examples/extras/loraplus/llama3_lora_sft.yaml
+++ b/examples/extras/loraplus/llama3_lora_sft.yaml
@@ -30,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/extras/mod/llama3_full_sft.yaml b/examples/extras/mod/llama3_full_sft.yaml
index df03c1e0..085febfc 100644
--- a/examples/extras/mod/llama3_full_sft.yaml
+++ b/examples/extras/mod/llama3_full_sft.yaml
@@ -31,6 +31,7 @@ num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml b/examples/extras/pissa/llama3_lora_sft.yaml
similarity index 88%
rename from examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
rename to examples/extras/pissa/llama3_lora_sft.yaml
index b308dcab..05077b6c 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
+++ b/examples/extras/pissa/llama3_lora_sft.yaml
@@ -1,12 +1,14 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-quantization_bit: 4
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
+pissa_init: true
+pissa_iter: 4
+pissa_convert: true
### dataset
dataset: identity,alpaca_en_demo
@@ -30,7 +32,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/full_multi_gpu/llama3_full_predict.yaml b/examples/train_full/llama3_full_predict.yaml
similarity index 100%
rename from examples/full_multi_gpu/llama3_full_predict.yaml
rename to examples/train_full/llama3_full_predict.yaml
diff --git a/examples/lora_multi_gpu/llama3_lora_sft.yaml b/examples/train_full/llama3_full_sft_ds3.yaml
similarity index 83%
rename from examples/lora_multi_gpu/llama3_lora_sft.yaml
rename to examples/train_full/llama3_full_sft_ds3.yaml
index 348e53b9..c983ad5c 100644
--- a/examples/lora_multi_gpu/llama3_lora_sft.yaml
+++ b/examples/train_full/llama3_full_sft_ds3.yaml
@@ -4,11 +4,8 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
### method
stage: sft
do_train: true
-finetuning_type: lora
-lora_target: all
-
-### ddp
-ddp_timeout: 180000000
+finetuning_type: full
+deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_en_demo
@@ -19,7 +16,7 @@ overwrite_cache: true
preprocessing_num_workers: 16
### output
-output_dir: saves/llama3-8b/lora/sft
+output_dir: saves/llama3-8b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
@@ -32,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_dpo.yaml b/examples/train_lora/llama3_lora_dpo.yaml
similarity index 87%
rename from examples/lora_single_gpu/llama3_lora_dpo.yaml
rename to examples/train_lora/llama3_lora_dpo.yaml
index 78344330..d87c0669 100644
--- a/examples/lora_single_gpu/llama3_lora_dpo.yaml
+++ b/examples/train_lora/llama3_lora_dpo.yaml
@@ -7,7 +7,7 @@ do_train: true
finetuning_type: lora
lora_target: all
pref_beta: 0.1
-pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
+pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: dpo_en_demo
@@ -31,7 +31,8 @@ learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_eval.yaml b/examples/train_lora/llama3_lora_eval.yaml
similarity index 100%
rename from examples/lora_single_gpu/llama3_lora_eval.yaml
rename to examples/train_lora/llama3_lora_eval.yaml
diff --git a/examples/lora_single_gpu/llama3_lora_kto.yaml b/examples/train_lora/llama3_lora_kto.yaml
similarity index 93%
rename from examples/lora_single_gpu/llama3_lora_kto.yaml
rename to examples/train_lora/llama3_lora_kto.yaml
index d5234c0a..08208c25 100644
--- a/examples/lora_single_gpu/llama3_lora_kto.yaml
+++ b/examples/train_lora/llama3_lora_kto.yaml
@@ -6,6 +6,7 @@ stage: kto
do_train: true
finetuning_type: lora
lora_target: all
+pref_beta: 0.1
### dataset
dataset: kto_en_demo
@@ -29,7 +30,8 @@ learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_ppo.yaml b/examples/train_lora/llama3_lora_ppo.yaml
similarity index 95%
rename from examples/lora_single_gpu/llama3_lora_ppo.yaml
rename to examples/train_lora/llama3_lora_ppo.yaml
index 98c842f9..512e90ea 100644
--- a/examples/lora_single_gpu/llama3_lora_ppo.yaml
+++ b/examples/train_lora/llama3_lora_ppo.yaml
@@ -30,7 +30,8 @@ learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### generate
max_new_tokens: 512
diff --git a/examples/lora_single_gpu/llama3_lora_predict.yaml b/examples/train_lora/llama3_lora_predict.yaml
similarity index 95%
rename from examples/lora_single_gpu/llama3_lora_predict.yaml
rename to examples/train_lora/llama3_lora_predict.yaml
index a127d248..148c8635 100644
--- a/examples/lora_single_gpu/llama3_lora_predict.yaml
+++ b/examples/train_lora/llama3_lora_predict.yaml
@@ -22,3 +22,4 @@ overwrite_output_dir: true
### eval
per_device_eval_batch_size: 1
predict_with_generate: true
+ddp_timeout: 180000000
diff --git a/examples/lora_single_gpu/llama3_lora_pretrain.yaml b/examples/train_lora/llama3_lora_pretrain.yaml
similarity index 94%
rename from examples/lora_single_gpu/llama3_lora_pretrain.yaml
rename to examples/train_lora/llama3_lora_pretrain.yaml
index db435ca9..5e8aaaef 100644
--- a/examples/lora_single_gpu/llama3_lora_pretrain.yaml
+++ b/examples/train_lora/llama3_lora_pretrain.yaml
@@ -28,7 +28,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_reward.yaml b/examples/train_lora/llama3_lora_reward.yaml
similarity index 91%
rename from examples/lora_single_gpu/llama3_lora_reward.yaml
rename to examples/train_lora/llama3_lora_reward.yaml
index 1ce42ea4..96c32238 100644
--- a/examples/lora_single_gpu/llama3_lora_reward.yaml
+++ b/examples/train_lora/llama3_lora_reward.yaml
@@ -25,11 +25,12 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
-learning_rate: 1.0e-5
+learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml
similarity index 95%
rename from examples/lora_single_gpu/llama3_lora_sft.yaml
rename to examples/train_lora/llama3_lora_sft.yaml
index 651b636f..55a8077e 100644
--- a/examples/lora_single_gpu/llama3_lora_sft.yaml
+++ b/examples/train_lora/llama3_lora_sft.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds0.yaml
similarity index 97%
rename from examples/lora_multi_npu/llama3_lora_sft_ds.yaml
rename to examples/train_lora/llama3_lora_sft_ds0.yaml
index a0ec8aa1..f1442faa 100644
--- a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+++ b/examples/train_lora/llama3_lora_sft_ds0.yaml
@@ -6,9 +6,6 @@ stage: sft
do_train: true
finetuning_type: lora
lora_target: all
-
-### ddp
-ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z0_config.json
### dataset
@@ -33,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds3.yaml
similarity index 97%
rename from examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
rename to examples/train_lora/llama3_lora_sft_ds3.yaml
index 1c432fa7..66e7007e 100644
--- a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+++ b/examples/train_lora/llama3_lora_sft_ds3.yaml
@@ -6,9 +6,6 @@ stage: sft
do_train: true
finetuning_type: lora
lora_target: all
-
-### ddp
-ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
@@ -33,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/lora_single_gpu/llama3_preprocess.yaml b/examples/train_lora/llama3_preprocess.yaml
similarity index 100%
rename from examples/lora_single_gpu/llama3_preprocess.yaml
rename to examples/train_lora/llama3_preprocess.yaml
diff --git a/examples/lora_single_gpu/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml
similarity index 95%
rename from examples/lora_single_gpu/llava1_5_lora_sft.yaml
rename to examples/train_lora/llava1_5_lora_sft.yaml
index df510a93..ec03f82c 100644
--- a/examples/lora_single_gpu/llava1_5_lora_sft.yaml
+++ b/examples/train_lora/llava1_5_lora_sft.yaml
@@ -30,7 +30,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml b/examples/train_qlora/llama3_lora_sft_aqlm.yaml
similarity index 95%
rename from examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
rename to examples/train_qlora/llama3_lora_sft_aqlm.yaml
index d54d6af6..3519d46b 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
+++ b/examples/train_qlora/llama3_lora_sft_aqlm.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml b/examples/train_qlora/llama3_lora_sft_awq.yaml
similarity index 95%
rename from examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
rename to examples/train_qlora/llama3_lora_sft_awq.yaml
index 5cef178a..df48669b 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
+++ b/examples/train_qlora/llama3_lora_sft_awq.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml b/examples/train_qlora/llama3_lora_sft_gptq.yaml
similarity index 95%
rename from examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
rename to examples/train_qlora/llama3_lora_sft_gptq.yaml
index b950042e..61fa9bb4 100644
--- a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
+++ b/examples/train_qlora/llama3_lora_sft_gptq.yaml
@@ -29,7 +29,8 @@ learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
-fp16: true
+bf16: true
+ddp_timeout: 180000000
### eval
val_size: 0.1
diff --git a/examples/train_qlora/llama3_lora_sft_otfq.yaml b/examples/train_qlora/llama3_lora_sft_otfq.yaml
new file mode 100644
index 00000000..80a05768
--- /dev/null
+++ b/examples/train_qlora/llama3_lora_sft_otfq.yaml
@@ -0,0 +1,41 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+quantization_bit: 4
+quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/llama3-8b/lora/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
diff --git a/requirements.txt b/requirements.txt
index 9e00555e..7380add4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ accelerate>=0.30.1
peft>=0.11.1
trl>=0.8.6
gradio>=4.0.0
+pandas>=2.0.0
scipy
einops
sentencepiece
@@ -17,3 +18,4 @@ matplotlib>=3.7.0
fire
packaging
pyyaml
+numpy<2.0.0
diff --git a/scripts/cal_flops.py b/scripts/cal_flops.py
index ac87e0ab..32526d89 100644
--- a/scripts/cal_flops.py
+++ b/scripts/cal_flops.py
@@ -1,7 +1,20 @@
# coding=utf-8
-# Calculates the flops of pre-trained models.
-# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
-# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
+# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
+#
+# This code is inspired by Microsoft's DeepSpeed library.
+# https://www.deepspeed.ai/tutorials/flops-profiler/
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import fire
import torch
@@ -17,6 +30,10 @@ def calculate_flops(
seq_length: int = 256,
flash_attn: str = "auto",
):
+ r"""
+ Calculates the flops of pre-trained models.
+ Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+ """
with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
diff --git a/scripts/cal_lr.py b/scripts/cal_lr.py
index bfa32cc9..ad6992cb 100644
--- a/scripts/cal_lr.py
+++ b/scripts/cal_lr.py
@@ -1,7 +1,20 @@
# coding=utf-8
-# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
-# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
-# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
+# Copyright 2024 imoneoi and the LlamaFactory team.
+#
+# This code is inspired by imoneoi's OpenChat library.
+# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import math
from typing import Literal
@@ -32,6 +45,10 @@ def calculate_lr(
cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate,
):
+ r"""
+ Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+ Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
+ """
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage=stage,
diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py
index 387b756c..fb503629 100644
--- a/scripts/cal_ppl.py
+++ b/scripts/cal_ppl.py
@@ -1,6 +1,17 @@
# coding=utf-8
-# Calculates the ppl on the dataset of the pre-trained models.
-# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
from dataclasses import dataclass
@@ -56,6 +67,10 @@ def cal_ppl(
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
+ r"""
+ Calculates the perplexity (ppl) of pre-trained models on the dataset.
+ Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
+ """
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
stage=stage,
diff --git a/scripts/length_cdf.py b/scripts/length_cdf.py
index 7739dcf0..4cdf01e6 100644
--- a/scripts/length_cdf.py
+++ b/scripts/length_cdf.py
@@ -1,6 +1,17 @@
# coding=utf-8
-# Calculates the distribution of the input lengths in the dataset.
-# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from collections import defaultdict
@@ -19,6 +30,10 @@ def length_cdf(
template: str = "default",
interval: int = 1000,
):
+ r"""
+ Calculates the distribution of the input lengths in the dataset.
+ Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+ """
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage="sft",
diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py
index 727998ae..17bf6fc2 100644
--- a/scripts/llama_pro.py
+++ b/scripts/llama_pro.py
@@ -1,7 +1,20 @@
# coding=utf-8
-# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
-# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
-# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
+# Copyright 2024 Tencent Inc. and the LlamaFactory team.
+#
+# This code is inspired by Tencent's LLaMA-Pro library.
+# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import os
@@ -37,6 +50,10 @@ def block_expansion(
shard_size: Optional[str] = "2GB",
save_safetensors: Optional[bool] = False,
):
+ r"""
+ Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
+ Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
+ """
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
num_layers = getattr(config, "num_hidden_layers")
setattr(config, "num_hidden_layers", num_layers + num_expand)
@@ -103,7 +120,7 @@ def block_expansion(
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
- print("Fine-tune this model with:")
+ print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("finetuning_type: freeze")
print("freeze_trainable_layers: {}".format(num_expand))
diff --git a/scripts/llamafy_baichuan2.py b/scripts/llamafy_baichuan2.py
index 1ae58879..19284f5f 100644
--- a/scripts/llamafy_baichuan2.py
+++ b/scripts/llamafy_baichuan2.py
@@ -1,8 +1,17 @@
# coding=utf-8
-# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
-# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
-# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
-# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import os
@@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
def llamafy_baichuan2(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
+ r"""
+    Converts the Baichuan2-7B model to the same format as LLaMA2-7B.
+ Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
+ Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+ """
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
diff --git a/scripts/llamafy_qwen.py b/scripts/llamafy_qwen.py
index 69cf3e8e..e5b59483 100644
--- a/scripts/llamafy_qwen.py
+++ b/scripts/llamafy_qwen.py
@@ -1,7 +1,17 @@
# coding=utf-8
-# Converts the Qwen models in the same format as LLaMA2.
-# Usage: python llamafy_qwen.py --input_dir input --output_dir output
-# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import os
@@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
def llamafy_qwen(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
+ r"""
+    Converts the Qwen models to the same format as LLaMA2.
+ Usage: python llamafy_qwen.py --input_dir input --output_dir output
+ Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
+ """
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
diff --git a/scripts/loftq_init.py b/scripts/loftq_init.py
index 7f244316..4d2c01b9 100644
--- a/scripts/loftq_init.py
+++ b/scripts/loftq_init.py
@@ -1,14 +1,25 @@
# coding=utf-8
-# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
-# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
-# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is based on HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
import fire
-import torch
-import torch.nn as nn
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -17,65 +28,61 @@ if TYPE_CHECKING:
from transformers import PreTrainedModel
-class Shell(nn.Module):
- def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
- super().__init__()
- self.weight = nn.Parameter(weight, requires_grad=False)
- if bias is not None:
- self.bias = nn.Parameter(bias, requires_grad=False)
-
-
-def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
- for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
- parent_name = ".".join(name.split(".")[:-1])
- child_name = name.split(".")[-1]
- parent_module = model.get_submodule(parent_name)
- child_module = getattr(parent_module, child_name)
- base_layer = getattr(child_module, "base_layer")
- weight = getattr(base_layer, "weight", None)
- bias = getattr(base_layer, "bias", None)
- setattr(parent_module, child_name, Shell(weight, bias))
-
- print("Model unwrapped.")
-
-
def quantize_loftq(
model_name_or_path: str,
- save_dir: str,
- loftq_bits: Optional[int] = 4,
- loftq_iter: Optional[int] = 1,
- lora_alpha: Optional[int] = None,
- lora_rank: Optional[int] = 16,
- lora_target: Optional[str] = "q_proj,v_proj",
- save_safetensors: Optional[bool] = False,
+ output_dir: str,
+ loftq_bits: int = 4,
+ loftq_iter: int = 4,
+ lora_alpha: int = None,
+ lora_rank: int = 16,
+ lora_dropout: float = 0,
+ lora_target: tuple = ("q_proj", "v_proj"),
+ save_safetensors: bool = True,
):
+ r"""
+ Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
+ Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
+ """
+ if isinstance(lora_target, str):
+ lora_target = [name.strip() for name in lora_target.split(",")]
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
+
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=True,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
- lora_dropout=0.1,
- target_modules=[name.strip() for name in lora_target.split(",")],
+ lora_dropout=lora_dropout,
+ target_modules=lora_target,
init_lora_weights="loftq",
loftq_config=loftq_config,
)
# Init LoftQ model
- lora_model = get_peft_model(model, lora_config)
- base_model: "PreTrainedModel" = lora_model.get_base_model()
+    print("Initializing LoftQ weights, this may take several minutes. Please wait patiently.")
+ peft_model = get_peft_model(model, lora_config)
+ loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model
- setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
- setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
- lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)
+ setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
+ setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
+ peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
+ print("Adapter weights saved in {}".format(loftq_dir))
# Save base model
- unwrap_model(base_model)
- base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
- tokenizer.save_pretrained(save_dir)
+ base_model: "PreTrainedModel" = peft_model.unload()
+ base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+ tokenizer.save_pretrained(output_dir)
+ print("Model weights saved in {}".format(output_dir))
+
+ print("- Fine-tune this model with:")
+ print("model_name_or_path: {}".format(output_dir))
+ print("adapter_name_or_path: {}".format(loftq_dir))
+ print("finetuning_type: lora")
+ print("quantization_bit: {}".format(loftq_bits))
if __name__ == "__main__":
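
Reading the new signature together with the printed hints, the intended workflow seems to be: run the script once to materialize both the dequantized base model and a LoftQ-initialized adapter, then point LLaMA-Factory at the pair. A rough sketch under those assumptions (module path and argument values are illustrative):

```python
# Illustrative only; the script is normally invoked through fire on the CLI:
#   python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
from loftq_init import quantize_loftq  # assumes running from the scripts/ directory

quantize_loftq(
    model_name_or_path="path_to_model",   # any causal LM checkpoint
    output_dir="output_dir",
    loftq_bits=4,
    loftq_iter=4,
    lora_rank=16,
    lora_target=("q_proj", "v_proj"),
)
# Expected layout afterwards:
#   output_dir/             -> base model weights + tokenizer
#   output_dir/loftq_init/  -> LoftQ-initialized LoRA adapter
# and the printed hints suggest fine-tuning with:
#   model_name_or_path: output_dir
#   adapter_name_or_path: output_dir/loftq_init
#   finetuning_type: lora
#   quantization_bit: 4
```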
diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py
new file mode 100644
index 00000000..ad9d161c
--- /dev/null
+++ b/scripts/pissa_init.py
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is based on HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import TYPE_CHECKING
+
+import fire
+from peft import LoraConfig, TaskType, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+
+def quantize_pissa(
+ model_name_or_path: str,
+ output_dir: str,
+ pissa_iter: int = 4,
+ lora_alpha: int = None,
+ lora_rank: int = 16,
+ lora_dropout: float = 0,
+ lora_target: tuple = ("q_proj", "v_proj"),
+ save_safetensors: bool = True,
+):
+ r"""
+ Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+ Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
+ """
+ if isinstance(lora_target, str):
+ lora_target = [name.strip() for name in lora_target.split(",")]
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
+
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ r=lora_rank,
+ lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
+ lora_dropout=lora_dropout,
+ target_modules=lora_target,
+ init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+ )
+
+ # Init PiSSA model
+ peft_model = get_peft_model(model, lora_config)
+ pissa_dir = os.path.join(output_dir, "pissa_init")
+
+ # Save PiSSA model
+ setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
+ peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
+ print("Adapter weights saved in {}".format(pissa_dir))
+
+ # Save base model
+ base_model: "PreTrainedModel" = peft_model.unload()
+ base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+ tokenizer.save_pretrained(output_dir)
+ print("Model weights saved in {}".format(output_dir))
+
+ print("- Fine-tune this model with:")
+ print("model_name_or_path: {}".format(output_dir))
+ print("adapter_name_or_path: {}".format(pissa_dir))
+ print("finetuning_type: lora")
+ print("pissa_init: false")
+ print("pissa_convert: true")
+ print("- and optionally with:")
+ print("quantization_bit: 4")
+
+
+if __name__ == "__main__":
+ fire.Fire(quantize_pissa)
diff --git a/scripts/test_toolcall.py b/scripts/test_toolcall.py
index 7e460017..6f6fd06c 100644
--- a/scripts/test_toolcall.py
+++ b/scripts/test_toolcall.py
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import Sequence
diff --git a/setup.py b/setup.py
index 405ac46e..d43c311c 100644
--- a/setup.py
+++ b/setup.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import re
@@ -23,14 +37,16 @@ extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
- "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
+ "deepspeed": ["deepspeed>=0.10.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
- "vllm": ["vllm>=0.4.3"],
- "galore": ["galore-torch"],
- "badam": ["badam"],
- "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
+ "hqq": ["hqq"],
+ "eetq": ["eetq"],
+ "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
+ "vllm": ["vllm>=0.4.3"],
+ "galore": ["galore-torch"],
+ "badam": ["badam>=1.2.1"],
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"dev": ["ruff", "pytest"],
diff --git a/src/api.py b/src/api.py
index 3655e393..0f925497 100644
--- a/src/api.py
+++ b/src/api.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import uvicorn
diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index 78230937..9d732777 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION
diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py
index 21edab2f..c1264617 100644
--- a/src/llamafactory/api/app.py
+++ b/src/llamafactory/api/app.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from contextlib import asynccontextmanager
from typing import Optional
diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index 98957bc1..72b2ae50 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import base64
import io
import json
@@ -78,9 +92,11 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
- name = message.tool_calls[0].function.name
- arguments = message.tool_calls[0].function.arguments
- content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+ tool_calls = [
+ {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
+ for tool_call in message.tool_calls
+ ]
+ content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
@@ -104,7 +120,7 @@ def _process_request(
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
- except Exception:
+ except json.JSONDecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = None
@@ -146,15 +162,17 @@ async def create_chat_completion_response(
choices = []
for i, response in enumerate(responses):
if tools:
- result = chat_model.engine.template.format_tools.extract(response.response_text)
+ result = chat_model.engine.template.extract_tool(response.response_text)
else:
result = response.response_text
- if isinstance(result, tuple):
- name, arguments = result
- function = Function(name=name, arguments=arguments)
- tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
- response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+ if isinstance(result, list):
+ tool_calls = []
+ for tool in result:
+ function = Function(name=tool[0], arguments=tool[1])
+ tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
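
Both the request and response paths above now treat tool calls as a JSON list rather than a single object. A small self-contained sketch of that round trip, with plain dicts and tuples standing in for the pydantic models (so the exact shapes are assumptions):

```python
import json

# Request side: each assistant tool_call contributes one {"name", "arguments"} entry,
# and the whole list is serialized into the function-role message content.
tool_calls = [
    {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
    {"name": "get_time", "arguments": "{}"},
]
content = json.dumps(tool_calls, ensure_ascii=False)

# Response side: extract_tool() returning a list of (name, arguments) tuples means
# one FunctionCall object is emitted per entry instead of only the first one.
result = [(call["name"], call["arguments"]) for call in json.loads(content)]
for name, arguments in result:
    print("call:", name, arguments)
```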
diff --git a/src/llamafactory/api/common.py b/src/llamafactory/api/common.py
index 5ad9a071..d1ac94de 100644
--- a/src/llamafactory/api/common.py
+++ b/src/llamafactory/api/common.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
from typing import TYPE_CHECKING, Any, Dict
diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py
index 055fa781..a69132ea 100644
--- a/src/llamafactory/api/protocol.py
+++ b/src/llamafactory/api/protocol.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import time
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union
diff --git a/src/llamafactory/chat/__init__.py b/src/llamafactory/chat/__init__.py
index a1a79de6..07276d48 100644
--- a/src/llamafactory/chat/__init__.py
+++ b/src/llamafactory/chat/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .base_engine import BaseEngine
from .chat_model import ChatModel
diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py
index 65b6c59c..ccdf4c92 100644
--- a/src/llamafactory/chat/base_engine.py
+++ b/src/llamafactory/chat/base_engine.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
@@ -36,11 +50,6 @@ class BaseEngine(ABC):
generating_args: "GeneratingArguments",
) -> None: ...
- @abstractmethod
- async def start(
- self,
- ) -> None: ...
-
@abstractmethod
async def chat(
self,
diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 281ef0c1..5c83fa67 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -1,3 +1,20 @@
+# Copyright 2024 THUDM and the LlamaFactory team.
+#
+# This code is inspired by THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
@@ -14,7 +31,7 @@ if TYPE_CHECKING:
from .base_engine import BaseEngine, Response
-def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
@@ -32,7 +49,6 @@ class ChatModel:
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
- asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
def chat(
self,
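
For context on why `engine.start()` could be dropped: `ChatModel` exposes a synchronous API on top of an async engine by parking an event loop in a daemon thread and submitting coroutines to it. A condensed, self-contained sketch of that pattern (a dummy coroutine stands in for the real engine):

```python
import asyncio
from threading import Thread

def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()

async def engine_chat(prompt: str) -> str:  # placeholder for BaseEngine.chat()
    await asyncio.sleep(0.1)
    return "echo: " + prompt

loop = asyncio.new_event_loop()
Thread(target=_start_background_loop, args=(loop,), daemon=True).start()

# Synchronous callers hand a coroutine to the background loop and block on the future;
# this is roughly what ChatModel.chat() does under the hood.
future = asyncio.run_coroutine_threadsafe(engine_chat("hello"), loop)
print(future.result())
```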
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 28e6a409..22a24339 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import asyncio
import concurrent.futures
import os
@@ -40,11 +54,19 @@ class HuggingfaceEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
- self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict()
+ try:
+ asyncio.get_event_loop()
+ except RuntimeError:
+ logger.warning("There is no current event loop, creating a new one.")
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
@staticmethod
def _process_args(
@@ -245,9 +267,6 @@ class HuggingfaceEngine(BaseEngine):
return scores
- async def start(self) -> None:
- self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
-
async def chat(
self,
messages: Sequence[Dict[str, str]],
@@ -272,7 +291,7 @@ class HuggingfaceEngine(BaseEngine):
image,
input_kwargs,
)
- async with self._semaphore:
+ async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
@@ -300,7 +319,7 @@ class HuggingfaceEngine(BaseEngine):
image,
input_kwargs,
)
- async with self._semaphore:
+ async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
@@ -319,6 +338,6 @@ class HuggingfaceEngine(BaseEngine):
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
- async with self._semaphore:
+ async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)
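
The semaphore that used to be created in the removed `start()` hook now lives in `__init__`, but its role is unchanged: bound the number of blocking generation calls running concurrently. A minimal sketch of that throttle, with a sleep standing in for the real `_chat()` call:

```python
import asyncio
import concurrent.futures
import os
import time

def blocking_generate(idx: int) -> str:  # placeholder for the real blocking generation
    time.sleep(0.2)
    return "response {}".format(idx)

async def main() -> None:
    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

    async def chat(idx: int) -> str:
        loop = asyncio.get_running_loop()
        async with semaphore:  # at most MAX_CONCURRENT generations in flight
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, blocking_generate, idx)

    print(await asyncio.gather(*(chat(i) for i in range(4))))

asyncio.run(main())
```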
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 87ce8684..f0d23676 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
from ..model import load_config, load_tokenizer
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
@@ -13,7 +27,11 @@ from .base_engine import BaseEngine, Response
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
- from vllm.sequence import MultiModalData
+
+ if is_vllm_version_greater_than_0_5():
+ from vllm.multimodal.image import ImagePixelData
+ else:
+ from vllm.sequence import MultiModalData
if TYPE_CHECKING:
@@ -41,14 +59,14 @@ class VllmEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
- self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
- "dtype": model_args.vllm_dtype,
+ "dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
@@ -106,7 +124,10 @@ class VllmEngine(BaseEngine):
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
- multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+ if is_vllm_version_greater_than_0_5():
+ multi_modal_data = ImagePixelData(image=pixel_values)
+ else: # TODO: remove vllm 0.4.3 support
+ multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
multi_modal_data = None
@@ -162,9 +183,6 @@ class VllmEngine(BaseEngine):
)
return result_generator
- async def start(self) -> None:
- pass
-
async def chat(
self,
messages: Sequence[Dict[str, str]],
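
`is_vllm_version_greater_than_0_5` comes from `..extras.packages`, whose body is not part of this excerpt; a plausible implementation of such a gate, shown purely as an assumption rather than the project's actual code, would compare the installed distribution version:

```python
# Hypothetical helper -- not the actual llamafactory.extras.packages implementation.
import importlib.metadata

from packaging import version

def is_vllm_version_greater_than_0_5() -> bool:
    try:
        # the >= 0.5.0 boundary is assumed from where ImagePixelData appears in vllm
        return version.parse(importlib.metadata.version("vllm")) >= version.parse("0.5.0")
    except importlib.metadata.PackageNotFoundError:
        return False
```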
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 5042e53c..48eb2898 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import random
import subprocess
@@ -60,7 +74,7 @@ class Command(str, Enum):
def main():
- command = sys.argv.pop(1)
+ command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
if command == Command.API:
run_api()
elif command == Command.CHAT:
@@ -77,7 +91,7 @@ def main():
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
- subprocess.run(
+ process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
@@ -92,6 +106,7 @@ def main():
),
shell=True,
)
+ sys.exit(process.returncode)
else:
run_exp()
elif command == Command.WEBDEMO:
diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py
index b08691d3..307853bc 100644
--- a/src/llamafactory/data/__init__.py
+++ b/src/llamafactory/data/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 434956af..299bdca3 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
@@ -10,6 +24,7 @@ from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
+ from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .parser import DatasetAttr
@@ -175,7 +190,10 @@ def convert_sharegpt(
def align_dataset(
- dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+ dataset: Union["Dataset", "IterableDataset"],
+ dataset_attr: "DatasetAttr",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
@@ -208,7 +226,7 @@ def align_dataset(
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=(not data_args.overwrite_cache),
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)
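
The new `load_from_cache_file` expression is easy to misread, so here is the full truth table it encodes (a standalone check, not project code):

```python
# load_from_cache_file = (not overwrite_cache) or (local_process_index != 0)
for overwrite_cache in (False, True):
    for local_process_index in (0, 1):
        load_from_cache_file = (not overwrite_cache) or (local_process_index != 0)
        print(overwrite_cache, local_process_index, "->", load_from_cache_file)
# overwrite_cache=False: every rank loads from cache.
# overwrite_cache=True : only local rank 0 recomputes; the other ranks reuse the
#                        cache it wrote while holding main_process_first().
```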
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 1dc8dd8d..e4859ff5 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import Any, Dict, Sequence
diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index 9b313112..76ded47e 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
from datasets import concatenate_datasets, interleave_datasets
@@ -16,6 +30,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+
+
@unique
class Role(str, Enum):
USER = "user"
@@ -25,13 +42,6 @@ class Role(str, Enum):
OBSERVATION = "observation"
-def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
- max_target_len = int(max_len * (target_len / (source_len + target_len)))
- max_target_len = max(max_target_len, reserved_label_len)
- max_source_len = max_len - min(max_target_len, target_len)
- return max_source_len, max_target_len
-
-
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py
index 0cd3d6c1..c1653a76 100644
--- a/src/llamafactory/data/formatter.py
+++ b/src/llamafactory/data/formatter.py
@@ -1,83 +1,36 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
-
-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
-
-
-JSON_FORMAT_PROMPT = (
- """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
-)
-
-
-TOOL_SYSTEM_PROMPT = (
- "You have access to the following tools:\n{tool_text}"
- "Use the following format if using a tool:\n"
- "```\n"
- "Action: tool name (one of [{tool_names}]).\n"
- "Action Input: the input to the tool{format_prompt}.\n"
- "```\n"
-)
-
-
-def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
- tool_text = ""
- tool_names = []
- for tool in tools:
- param_text = ""
- for name, param in tool["parameters"]["properties"].items():
- required = ", required" if name in tool["parameters"].get("required", []) else ""
- enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
- items = (
- ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
- )
- param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
- name=name,
- type=param.get("type", ""),
- required=required,
- desc=param.get("description", ""),
- enum=enum,
- items=items,
- )
-
- tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
- name=tool["name"], desc=tool.get("description", ""), args=param_text
- )
- tool_names.append(tool["name"])
-
- return TOOL_SYSTEM_PROMPT.format(
- tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
- )
-
-
-def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
- regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
- action_match = re.search(regex, content)
- if not action_match:
- return content
-
- tool_name = action_match.group(1).strip()
- tool_input = action_match.group(2).strip().strip('"').strip("```")
- try:
- arguments = json.loads(tool_input)
- except json.JSONDecodeError:
- return content
-
- return tool_name, json.dumps(arguments, ensure_ascii=False)
+from .data_utils import SLOTS
+from .tool_utils import DefaultToolUtils, GLM4ToolUtils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
- tool_format: Optional[Literal["default"]] = None
+ tool_format: Optional[Literal["default", "glm4"]] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
- def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
raise NotImplementedError
@@ -128,34 +81,37 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
- has_name, has_args = False, False
- for slot in filter(lambda s: isinstance(s, str), self.slots):
- if "{{name}}" in slot:
- has_name = True
- if "{{arguments}}" in slot:
- has_args = True
-
- if not has_name or not has_args:
- raise ValueError("Name and arguments placeholders are required in the function formatter.")
+ if self.tool_format == "default":
+ self.slots = DefaultToolUtils.get_function_slots() + self.slots
+ elif self.tool_format == "glm4":
+ self.slots = GLM4ToolUtils.get_function_slots() + self.slots
+ else:
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
+ functions: List[Tuple[str, str]] = []
try:
- function = json.loads(content)
- name = function["name"]
- arguments = json.dumps(function["arguments"], ensure_ascii=False)
- except Exception:
- name, arguments = "", ""
+ tool_calls = json.loads(content)
+ if not isinstance(tool_calls, list): # parallel function call
+ tool_calls = [tool_calls]
+
+ for tool_call in tool_calls:
+ functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+ except json.JSONDecodeError:
+ functions = []
elements = []
- for slot in self.slots:
- if isinstance(slot, str):
- slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
- elements.append(slot)
- elif isinstance(slot, (dict, set)):
- elements.append(slot)
- else:
- raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+ for name, arguments in functions:
+ for slot in self.slots:
+ if isinstance(slot, str):
+ slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
+ elements.append(slot)
+ elif isinstance(slot, (dict, set)):
+ elements.append(slot)
+ else:
+ raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@@ -163,25 +119,22 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
- if self.tool_format is None:
- raise ValueError("Tool format was not found.")
+ if self.tool_format == "default":
+ self._tool_formatter = DefaultToolUtils.tool_formatter
+ self._tool_extractor = DefaultToolUtils.tool_extractor
+ elif self.tool_format == "glm4":
+ self._tool_formatter = GLM4ToolUtils.tool_formatter
+ self._tool_extractor = GLM4ToolUtils.tool_extractor
+ else:
+ raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
- if not len(tools):
- return [""]
-
- if self.tool_format == "default":
- return [default_tool_formatter(tools)]
- else:
- raise NotImplementedError
- except Exception:
+ return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+ except json.JSONDecodeError:
return [""]
- def extract(self, content: str) -> Union[str, Tuple[str, str]]:
- if self.tool_format == "default":
- return default_tool_extractor(content)
- else:
- raise NotImplementedError
+ def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ return self._tool_extractor(content)
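
The rewritten `FunctionFormatter.apply` walks the slots once per decoded tool call, so parallel calls expand into several formatted segments. A standalone sketch of that expansion; the single slot string here is invented for illustration (the real ones come from `DefaultToolUtils.get_function_slots()`, which is not shown in this diff):

```python
import json

slots = ["Action: {{name}}\nAction Input: {{arguments}}\n"]  # illustrative slot only

content = json.dumps([
    {"name": "search", "arguments": {"query": "llama"}},
    {"name": "translate", "arguments": {"text": "hello", "lang": "zh"}},
])

functions = []
for tool_call in json.loads(content):  # parallel calls arrive as a JSON list
    functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))

elements = []
for name, arguments in functions:  # the slots are replayed once per call
    for slot in slots:
        elements.append(slot.replace("{{name}}", name).replace("{{arguments}}", arguments))

print(elements)  # two formatted segments, one per tool call
```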
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 2c236c76..8e7062db 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import inspect
import os
import sys
@@ -18,8 +32,7 @@ from .template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
- from transformers import ProcessorMixin, Seq2SeqTrainingArguments
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr
@@ -32,6 +45,7 @@ def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +137,7 @@ def load_single_dataset(
max_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(max_samples))
- return align_dataset(dataset, dataset_attr, data_args)
+ return align_dataset(dataset, dataset_attr, data_args, training_args)
def get_dataset(
@@ -134,7 +148,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
- template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
+ template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
@@ -157,7 +171,8 @@ def get_dataset(
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
- all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+ all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
+
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +184,7 @@ def get_dataset(
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=(not data_args.overwrite_cache),
+ load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Running tokenizer on dataset",
)
diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py
index ec97bfc1..4bebcd68 100644
--- a/src/llamafactory/data/parser.py
+++ b/src/llamafactory/data/parser.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from dataclasses import dataclass
diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index cf207d7e..3a80900c 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
@@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
if TYPE_CHECKING:
- from transformers import ProcessorMixin, Seq2SeqTrainingArguments
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .template import Template
diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py
index 98d83658..7ba05e23 100644
--- a/src/llamafactory/data/processors/feedback.py
+++ b/src/llamafactory/data/processors/feedback.py
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -42,12 +55,8 @@ def _encode_feedback_example(
else:
kl_messages = prompt + [kl_response[1]]
- prompt_ids, response_ids = template.encode_oneturn(
- tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
- _, kl_response_ids = template.encode_oneturn(
- tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+ _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
@@ -57,6 +66,12 @@ def _encode_feedback_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+ # do not consider the kl_response
+ source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+ prompt_ids = prompt_ids[:source_len]
+ response_ids = response_ids[:target_len]
+ kl_response_ids = kl_response_ids[:target_len]
+
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py
index fe984efa..c6001e6e 100644
--- a/src/llamafactory/data/processors/pairwise.py
+++ b/src/llamafactory/data/processors/pairwise.py
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -31,12 +44,8 @@ def _encode_pairwise_example(
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
- prompt_ids, chosen_ids = template.encode_oneturn(
- tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
- _, rejected_ids = template.encode_oneturn(
- tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+ _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
@@ -46,6 +55,13 @@ def _encode_pairwise_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+ source_len, target_len = infer_seqlen(
+ len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+ ) # consider the response is more important
+ prompt_ids = prompt_ids[:source_len]
+ chosen_ids = chosen_ids[:target_len]
+ rejected_ids = rejected_ids[:target_len]
+
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py
index 87727b55..67d6009b 100644
--- a/src/llamafactory/data/processors/pretrain.py
+++ b/src/llamafactory/data/processors/pretrain.py
@@ -1,9 +1,26 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer
from ...hparams import DataArguments
@@ -12,7 +29,8 @@ def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
- text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+ eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+ text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py
index 9903a053..455908ae 100644
--- a/src/llamafactory/data/processors/processor_utils.py
+++ b/src/llamafactory/data/processors/processor_utils.py
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import bisect
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Sequence, Tuple
from ...extras.packages import is_pillow_available
@@ -62,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+ if target_len * 2 < cutoff_len: # truncate source
+ max_target_len = cutoff_len
+ elif source_len * 2 < cutoff_len: # truncate target
+ max_target_len = cutoff_len - source_len
+ else: # truncate both
+ max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+ new_target_len = min(max_target_len, target_len)
+ new_source_len = max(cutoff_len - new_target_len, 0)
+ return new_source_len, new_target_len
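The new `infer_seqlen` helper replaces the old `reserved_label_len` mechanism: it keeps whichever side already fits within half the budget and splits the remainder proportionally when both sides are too long. A minimal standalone sketch; the function body mirrors the hunk above, so the snippet runs without installing the package:

```python
from typing import Tuple


def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    # Mirror of the helper added above: decide how many prompt (source) and
    # response (target) tokens survive a given cutoff.
    if target_len * 2 < cutoff_len:  # truncate source
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:  # truncate target
        max_target_len = cutoff_len - source_len
    else:  # truncate both proportionally
        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))

    new_target_len = min(max_target_len, target_len)
    new_source_len = max(cutoff_len - new_target_len, 0)
    return new_source_len, new_target_len


# Short response: the prompt absorbs the truncation.
assert infer_seqlen(900, 100, 512) == (412, 100)
# Short prompt: the response absorbs the truncation.
assert infer_seqlen(100, 900, 512) == (100, 412)
# Both long: the budget is split proportionally.
assert infer_seqlen(600, 600, 512) == (256, 256)
```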
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 35640174..8ef55321 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -1,14 +1,27 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -38,10 +51,17 @@ def _encode_supervised_example(
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
- encoded_pairs = template.encode_multiturn(
- tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+ total_length = 1 if template.efficient_eos else 0
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+ if total_length >= data_args.cutoff_len:
+ break
+
+ source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+ source_ids = source_ids[:source_len]
+ target_ids = target_ids[:target_len]
+ total_length += source_len + target_len
+
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py
index f711eeac..b3fc85c9 100644
--- a/src/llamafactory/data/processors/unsupervised.py
+++ b/src/llamafactory/data/processors/unsupervised.py
@@ -1,13 +1,26 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
- from transformers import ProcessorMixin
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@@ -34,9 +47,7 @@ def _encode_unsupervised_example(
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
- input_ids, labels = template.encode_oneturn(
- tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
- )
+ input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
@@ -44,6 +55,9 @@ def _encode_unsupervised_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
+ source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+ input_ids = input_ids[:source_len]
+ labels = labels[:target_len]
return input_ids, labels
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b600c567..aefd5195 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1,8 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
-from .data_utils import Role, infer_max_len
+from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@@ -24,69 +38,74 @@ class Template:
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
+ format_prefix: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
- force_system: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
- cutoff_len: int = 1_000_000,
- reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
- encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
- for query_ids, resp_ids in encoded_pairs[:-1]:
- prompt_ids += query_ids + resp_ids
- prompt_ids = prompt_ids + encoded_pairs[-1][0]
- answer_ids = encoded_pairs[-1][1]
+ for encoded_ids in encoded_messages[:-1]:
+ prompt_ids += encoded_ids
+
+ answer_ids = encoded_messages[-1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
- cutoff_len: int = 1_000_000,
- reserved_label_len: int = 1,
- ) -> Sequence[Tuple[List[int], List[int]]]:
+ ) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
- return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+ encoded_messages = self._encode(tokenizer, messages, system, tools)
+ return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+ def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+ r"""
+ Extracts tool message.
+ """
+ return self.format_tools.extract(content)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
- cutoff_len: int,
- reserved_label_len: int,
- ) -> Sequence[Tuple[List[int], List[int]]]:
+ ) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
- Turn 0: system + query resp
- Turn t: sep + query resp
+ Turn 0: prefix + system + query resp
+ Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
- if i == 0 and (system or tools or self.force_system):
- tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
- elements += self.format_system.apply(content=(system + tool_text))
- elif i > 0 and i % 2 == 0:
+
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ elements += self.format_system.apply(content=(system + tool_text))
+
+ if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@@ -102,11 +121,9 @@ class Template:
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
- return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+ return encoded_messages
- def _convert_elements_to_ids(
- self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
- ) -> List[int]:
+ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
@@ -127,57 +144,34 @@ class Template:
return token_ids
- def _make_pairs(
- self,
- encoded_messages: Sequence[List[int]],
- cutoff_len: int,
- reserved_label_len: int,
- ) -> Sequence[Tuple[List[int], List[int]]]:
- encoded_pairs = []
- total_length = 0
- for i in range(0, len(encoded_messages), 2):
- if total_length >= cutoff_len:
- break
-
- max_source_len, max_target_len = infer_max_len(
- source_len=len(encoded_messages[i]),
- target_len=len(encoded_messages[i + 1]),
- max_len=(cutoff_len - total_length),
- reserved_label_len=reserved_label_len,
- )
- source_ids = encoded_messages[i][:max_source_len]
- target_ids = encoded_messages[i + 1][:max_target_len]
- total_length += len(source_ids) + len(target_ids)
- encoded_pairs.append((source_ids, target_ids))
-
- return encoded_pairs
-
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
- messages: List[Dict[str, str]],
+ messages: Sequence[Dict[str, str]],
system: str,
tools: str,
- cutoff_len: int,
- reserved_label_len: int,
- ) -> Sequence[Tuple[List[int], List[int]]]:
+ ) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
- Turn 0: system + query resp
- Turn t: sep + query resp
+ Turn 0: prefix + system + query resp
+ Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
+
system_text = ""
- if i == 0 and (system or tools or self.force_system):
- tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
- system_text = self.format_system.apply(content=(system + tool_text))[0]
- elif i > 0 and i % 2 == 0:
+ if i == 0:
+ elements += self.format_prefix.apply()
+ if system or tools:
+ tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+ system_text = self.format_system.apply(content=(system + tool_text))[0]
+
+ if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@@ -193,7 +187,7 @@ class Llama2Template(Template):
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
- return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+ return encoded_messages
TEMPLATES: Dict[str, Template] = {}
@@ -208,12 +202,12 @@ def _register_template(
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
+ format_prefix: Optional["Formatter"] = None,
default_system: str = "",
- stop_words: List[str] = [],
+ stop_words: Sequence[str] = [],
image_token: str = "",
efficient_eos: bool = False,
replace_eos: bool = False,
- force_system: bool = False,
) -> None:
r"""
Registers a chat template.
@@ -245,9 +239,10 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
- default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
+ default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
+ default_prefix_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
@@ -256,12 +251,12 @@ def _register_template(
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
+ format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
- force_system=force_system,
)
@@ -307,6 +302,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
jinja_template = ""
+ prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
+ if prefix:
+ jinja_template += "{{ " + prefix + " }}"
+
if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
@@ -315,11 +314,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
- if isinstance(template, Llama2Template):
- pass
- elif template.force_system:
- jinja_template += "{{ " + system_message + " }}"
- else:
+ if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}"
@@ -346,6 +341,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
+ tool_format: Optional[str] = None,
) -> Template:
if name is None:
template = TEMPLATES["empty"] # placeholder
@@ -354,6 +350,12 @@ def get_template_and_fix_tokenizer(
if template is None:
raise ValueError("Template {} does not exist.".format(name))
+ if tool_format is not None:
+ logger.info("Using tool format: {}.".format(tool_format))
+ eos_slots = [] if template.efficient_eos else [{"eos_token"}]
+ template.format_tools = ToolFormatter(tool_format=tool_format)
+ template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
@@ -435,9 +437,8 @@ _register_template(
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -450,11 +451,7 @@ _register_template(
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- default_system=(
- "You are a helpful AI assistant built by MediaTek Research. "
- "The user you are helping speaks Traditional Chinese and comes from Taiwan."
- ),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
)
@@ -462,10 +459,9 @@ _register_template(
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
- format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
- force_system=True,
)
@@ -473,32 +469,13 @@ _register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
- format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
- format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+ format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
- stop_words=["<|user|>", "<|observation|>"],
- efficient_eos=True,
- force_system=True,
-)
-
-
-_register_template(
- name="chatglm3_system",
- format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
- format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
- format_system=StringFormatter(
- slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
- ),
- format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
- format_observation=StringFormatter(
- slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
- ),
- default_system=(
- "You are ChatGLM3, a large language model trained by Zhipu.AI. "
- "Follow the user's instructions carefully. Respond using markdown."
- ),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
@@ -529,8 +506,7 @@ _register_template(
_register_template(
name="codegeex2",
- format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
@@ -544,21 +520,15 @@ _register_template(
)
]
),
- format_system=StringFormatter(
- slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
- ),
- default_system=(
- "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
- "by providing thorough responses. You are trained by Cohere."
- ),
+ format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -591,30 +561,28 @@ _register_template(
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
- format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
- format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
- stop_words=["<|EOT|>"],
- efficient_eos=True,
)
_register_template(
name="default",
- format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
+ format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
@@ -622,11 +590,7 @@ _register_template(
_register_template(
name="empty",
- format_user=StringFormatter(slots=["{{content}}"]),
- format_assistant=StringFormatter(slots=["{{content}}"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
efficient_eos=True,
- force_system=True,
)
@@ -648,13 +612,12 @@ _register_template(
_register_template(
name="gemma",
format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["tool\n{{content}}\nmodel\n"]
),
format_separator=EmptyFormatter(slots=["\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
- force_system=True,
)
@@ -662,36 +625,33 @@ _register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
- format_system=StringFormatter(slots=["[gMASK]{{content}}"]),
- format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
- force_system=True,
)
_register_template(
name="intern",
- format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]),
- format_separator=EmptyFormatter(slots=[{"token": ""}, "\n"]),
+ format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+ format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
+ format_separator=EmptyFormatter(slots=["<eoa>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=[""],
- efficient_eos=True,
+ efficient_eos=True, # internlm tokenizer cannot set eos_token_id
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
- format_separator=EmptyFormatter(slots=["\n"]),
- default_system=(
- "You are an AI assistant whose name is InternLM (书生·浦语).\n"
- "- InternLM (书生·浦语) is a conversational language model that is developed "
- "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
- "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
- "by the user such as English and 中文."
- ),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
@@ -700,7 +660,6 @@ _register_template(
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
- format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]),
)
@@ -723,9 +682,7 @@ _register_template(
)
]
),
- format_system=StringFormatter(
- slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
- ),
+ format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
@@ -734,7 +691,7 @@ _register_template(
)
]
),
- default_system="You are a helpful assistant.",
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
)
@@ -743,24 +700,21 @@ _register_template(
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
- format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -774,27 +728,25 @@ _register_template(
)
]
),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
- force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
- format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
- force_system=True,
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
- format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
- default_system="You are a helpful AI assistant.",
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
@@ -827,7 +779,6 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
replace_eos=True,
- force_system=True,
)
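With `_make_pairs` removed, `Template._encode` now returns one token-id list per formatted message, and the two public entry points assemble them differently: `encode_multiturn` zips consecutive lists into (prompt, response) pairs, while `encode_oneturn` folds everything but the last message into a single prompt. A small sketch of both strategies, with placeholder integer lists standing in for real token ids:

```python
from typing import List, Tuple

# Pretend _encode produced one id list per message: user, assistant, user, assistant.
encoded_messages: List[List[int]] = [[1, 2], [3], [4, 5], [6, 7]]

# encode_multiturn-style pairing: one (source, target) tuple per turn.
pairs: List[Tuple[List[int], List[int]]] = [
    (encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)
]
assert pairs == [([1, 2], [3]), ([4, 5], [6, 7])]

# encode_oneturn-style flattening: every message except the last forms the prompt.
prompt_ids: List[int] = []
for encoded_ids in encoded_messages[:-1]:
    prompt_ids += encoded_ids
answer_ids = encoded_messages[-1]
assert prompt_ids == [1, 2, 3, 4, 5] and answer_ids == [6, 7]
```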
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
new file mode 100644
index 00000000..ac5565d5
--- /dev/null
+++ b/src/llamafactory/data/tool_utils.py
@@ -0,0 +1,140 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+from .data_utils import SLOTS
+
+
+DEFAULT_TOOL_PROMPT = (
+ "You have access to the following tools:\n{tool_text}"
+ "Use the following format if using a tool:\n"
+ "```\n"
+ "Action: tool name (one of [{tool_names}]).\n"
+ "Action Input: the input to the tool, in a JSON format representing the kwargs "
+ """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
+ "```\n"
+)
+
+
+GLM4_TOOL_PROMPT = (
+ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+ "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+)
+
+
+@dataclass
+class ToolUtils(ABC):
+ @staticmethod
+ @abstractmethod
+ def get_function_slots() -> SLOTS: ...
+
+ @staticmethod
+ @abstractmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+
+ @staticmethod
+ @abstractmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+
+
+class DefaultToolUtils(ToolUtils):
+ @staticmethod
+ def get_function_slots() -> SLOTS:
+ return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ tool_names = []
+ for tool in tools:
+ param_text = ""
+ for name, param in tool["parameters"]["properties"].items():
+ required, enum, items = "", "", ""
+ if name in tool["parameters"].get("required", []):
+ required = ", required"
+
+ if param.get("enum", None):
+ enum = ", should be one of [{}]".format(", ".join(param["enum"]))
+
+ if param.get("items", None):
+ items = ", where each item should be {}".format(param["items"].get("type", ""))
+
+ param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
+ name=name,
+ type=param.get("type", ""),
+ required=required,
+ desc=param.get("description", ""),
+ enum=enum,
+ items=items,
+ )
+
+ tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+ name=tool["name"], desc=tool.get("description", ""), args=param_text
+ )
+ tool_names.append(tool["name"])
+
+ return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+ regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
+ action_match: List[Tuple[str, str]] = re.findall(regex, content)
+ if not action_match:
+ return content
+
+ results = []
+ for match in action_match:
+ tool_name = match[0].strip()
+ tool_input = match[1].strip().strip('"').strip("```")
+ try:
+ arguments = json.loads(tool_input)
+ results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+ except json.JSONDecodeError:
+ return content
+
+ return results
+
+
+class GLM4ToolUtils(ToolUtils):
+ @staticmethod
+ def get_function_slots() -> SLOTS:
+ return ["{{name}}\n{{arguments}}"]
+
+ @staticmethod
+ def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
+ tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+ name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+ )
+
+ return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+ if "\n" not in content:
+ return content
+
+ tool_name, tool_input = content.split("\n", maxsplit=1)
+ try:
+ arguments = json.loads(tool_input)
+ except json.JSONDecodeError:
+ return content
+
+ return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
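As a rough illustration of what `DefaultToolUtils.tool_extractor` is meant to parse, the snippet below applies the same regex and JSON round-trip to an invented model output; the tool name and arguments are made-up examples, not part of the patch:

```python
import json
import re

# Invented model output following the default "Action / Action Input" format.
content = 'Action: get_weather\nAction Input: {"location": "Berlin", "unit": "celsius"}'

# Same pattern as DefaultToolUtils.tool_extractor above.
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
calls = [
    (name.strip(), json.dumps(json.loads(args.strip()), ensure_ascii=False))
    for name, args in re.findall(regex, content)
]
print(calls)  # [('get_weather', '{"location": "Berlin", "unit": "celsius"}')]
```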
diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py
index 192f4815..d3140793 100644
--- a/src/llamafactory/eval/evaluator.py
+++ b/src/llamafactory/eval/evaluator.py
@@ -1,4 +1,41 @@
-# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by Dan Hendrycks' test library.
+# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2020 Dan Hendrycks
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
import inspect
import json
@@ -26,9 +63,7 @@ class Evaluator:
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang)
- self.choice_inputs = [
- self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
- ]
+ self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
diff --git a/src/llamafactory/eval/template.py b/src/llamafactory/eval/template.py
index a4a6ef0e..7d524e7c 100644
--- a/src/llamafactory/eval/template.py
+++ b/src/llamafactory/eval/template.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
@@ -10,7 +24,6 @@ class EvalTemplate:
system: str
choice: str
answer: str
- prefix: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
@@ -42,8 +55,8 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {}
-def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
- eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
+def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
+ eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
def get_eval_template(name: str) -> "EvalTemplate":
@@ -56,8 +69,7 @@ _register_eval_template(
name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}",
- answer="\nAnswer: ",
- prefix=" ",
+ answer="\nAnswer:",
)
@@ -66,5 +78,4 @@ _register_eval_template(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}",
answer="\n答案:",
- prefix=" ",
)
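Dropping the `prefix` field pairs with the evaluator change above: choice letters are now encoded on their own, so the English answer slot ends with a bare `Answer:` rather than `Answer: `. A rough sketch of how the template's three format strings compose into a prompt; the subject, question, and options are invented, and the real assembly lives in `_parse_example`, which this hunk does not show:

```python
# Format strings from the "en" eval template registered above.
system = "The following are multiple choice questions (with answers) about {subject}.\n\n"
choice = "\n{choice}. {content}"
answer = "\nAnswer:"

question = "What is 2 + 2?"  # invented example
options = {"A": "3", "B": "4", "C": "5", "D": "22"}

prompt = system.format(subject="arithmetic") + question
prompt += "".join(choice.format(choice=ch, content=text) for ch, text in options.items())
prompt += answer
print(prompt)
# The following are multiple choice questions (with answers) about arithmetic.
#
# What is 2 + 2?
# A. 3
# B. 4
# C. 5
# D. 22
# Answer:
```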
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 466b1269..6029d84f 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
@@ -404,6 +418,18 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
},
+ "DeepSeek-MoE-Coder-16B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
+ },
+ "DeepSeek-MoE-Coder-236B-Base": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
+ },
+ "DeepSeek-MoE-Coder-16B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
+ },
+ "DeepSeek-MoE-Coder-236B-Chat": {
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+ },
},
template="deepseek",
)
@@ -496,6 +522,18 @@ register_model_group(
"Gemma-1.1-7B-Chat": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
+ "Gemma-2-9B": {
+ DownloadSource.DEFAULT: "google/gemma-2-9b",
+ },
+ "Gemma-2-27B": {
+ DownloadSource.DEFAULT: "google/gemma-2-27b",
+ },
+ "Gemma-2-9B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-2-9b-it",
+ },
+ "Gemma-2-27B-Chat": {
+ DownloadSource.DEFAULT: "google/gemma-2-27b-it",
+ },
},
template="gemma",
)
@@ -568,7 +606,7 @@ register_model_group(
register_model_group(
models={
- "Jambda-v0.1": {
+ "Jamba-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
@@ -683,6 +721,21 @@ register_model_group(
)
+register_model_group(
+ models={
+ "MiniCPM-2B-SFT-Chat": {
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
+ DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
+ },
+ "MiniCPM-2B-DPO-Chat": {
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
+ DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
+ },
+ },
+ template="cpm",
+)
+
+
register_model_group(
models={
"Mistral-7B-v0.1": {
diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py
index 1d4e43f1..14876048 100644
--- a/src/llamafactory/extras/env.py
+++ b/src/llamafactory/extras/env.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import platform
import accelerate
@@ -9,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
-VERSION = "0.8.1.dev0"
+VERSION = "0.8.3.dev0"
def print_env() -> None:
diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py
index 430b8a48..67622212 100644
--- a/src/llamafactory/extras/logging.py
+++ b/src/llamafactory/extras/logging.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import logging
import os
import sys
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index fc33f77e..20c752c5 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -1,13 +1,29 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import gc
import os
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple
import torch
-from peft import PeftModel
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
+import transformers.dynamic_module_utils
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from transformers.dynamic_module_utils import get_relative_imports
from transformers.utils import (
- SAFE_WEIGHTS_NAME,
- WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_mps_available,
@@ -16,7 +32,6 @@ from transformers.utils import (
)
from transformers.utils.versions import require_version
-from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
@@ -28,8 +43,6 @@ except Exception:
if TYPE_CHECKING:
- from trl import AutoModelForCausalLMWithValueHead
-
from ..hparams import ModelArguments
@@ -58,6 +71,9 @@ class AverageMeter:
def check_dependencies() -> None:
+ r"""
+ Checks the version of the required packages.
+ """
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
@@ -68,7 +84,7 @@ def check_dependencies() -> None:
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
-def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
+def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
@@ -79,7 +95,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
- # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
+ # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
@@ -97,55 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
-def fix_valuehead_checkpoint(
- model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
-) -> None:
- r"""
- The model is already unwrapped.
-
- There are three cases:
- 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
- 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
- 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
-
- We assume `stage3_gather_16bit_weights_on_model_save=true`.
- """
- if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
- return
-
- if safe_serialization:
- from safetensors import safe_open
- from safetensors.torch import save_file
-
- path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
- with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
- state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
- else:
- path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
- state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-
- decoder_state_dict = {}
- v_head_state_dict = {}
- for name, param in state_dict.items():
- if name.startswith("v_head."):
- v_head_state_dict[name] = param
- else:
- decoder_state_dict[name.replace("pretrained_model.", "")] = param
-
- os.remove(path_to_checkpoint)
- model.pretrained_model.save_pretrained(
- output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
- )
-
- if safe_serialization:
- save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
- else:
- torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-
- logger.info("Value head model saved at: {}".format(output_dir))
-
-
-def get_current_device() -> torch.device:
+def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
@@ -184,7 +152,14 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
-def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+def has_tokenized_data(path: "os.PathLike") -> bool:
+ r"""
+ Checks if the path has a tokenized dataset.
+ """
+ return os.path.isdir(path) and len(os.listdir(path)) > 0
+
+
+def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
@@ -203,11 +178,9 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available()
-def has_tokenized_data(path: os.PathLike) -> bool:
- r"""
- Checks if the path has a tokenized dataset.
- """
- return os.path.isdir(path) and len(os.listdir(path)) > 0
+def skip_check_imports() -> None:
+ if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+ transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 4c9e6492..0a84a293 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -1,5 +1,23 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import importlib.metadata
import importlib.util
+from functools import lru_cache
from typing import TYPE_CHECKING
from packaging import version
@@ -24,10 +42,6 @@ def is_fastapi_available():
return _is_package_available("fastapi")
-def is_flash_attn2_available():
- return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
-
-
def is_galore_available():
return _is_package_available("galore_torch")
@@ -36,18 +50,10 @@ def is_gradio_available():
return _is_package_available("gradio")
-def is_jieba_available():
- return _is_package_available("jieba")
-
-
def is_matplotlib_available():
return _is_package_available("matplotlib")
-def is_nltk_available():
- return _is_package_available("nltk")
-
-
def is_pillow_available():
return _is_package_available("PIL")
@@ -60,10 +66,6 @@ def is_rouge_available():
return _is_package_available("rouge_chinese")
-def is_sdpa_available():
- return _get_package_version("torch") > version.parse("2.1.1")
-
-
def is_starlette_available():
return _is_package_available("sse_starlette")
@@ -74,3 +76,8 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
+
+
+@lru_cache
+def is_vllm_version_greater_than_0_5():
+ return _get_package_version("vllm") >= version.parse("0.5.0")
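The new `is_vllm_version_greater_than_0_5` check compares versions with `packaging.version` instead of raw strings and memoizes the result with `lru_cache`. A quick illustration of why semantic comparison matters, using a hypothetical `is_at_least` helper rather than the package probes defined in this module:

```python
from functools import lru_cache

from packaging import version


@lru_cache
def is_at_least(ver: str, minimum: str) -> bool:
    # packaging's Version compares release segments numerically, not lexicographically.
    return version.parse(ver) >= version.parse(minimum)


assert is_at_least("0.5.0", "0.5.0")
assert is_at_least("0.10.1", "0.5.0")  # true, even though "0.10.1" < "0.5.0" as a plain string
assert not is_at_least("0.4.2", "0.5.0")
```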
diff --git a/src/llamafactory/extras/ploting.py b/src/llamafactory/extras/ploting.py
index dea23bbe..596d55e7 100644
--- a/src/llamafactory/extras/ploting.py
+++ b/src/llamafactory/extras/ploting.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import math
import os
diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py
index d1ee98dd..cfe448c1 100644
--- a/src/llamafactory/hparams/__init__.py
+++ b/src/llamafactory/hparams/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index d2d53ec8..e351fccf 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass, field
from typing import Literal, Optional
@@ -28,10 +45,6 @@ class DataArguments:
default=1024,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
- reserved_label_len: int = field(
- default=1,
- metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
- )
train_on_prompt: bool = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."},
@@ -90,15 +103,16 @@ class DataArguments:
"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
},
)
+ tool_format: Optional[str] = field(
+ default=None,
+ metadata={"help": "Tool format to use for constructing function calling examples."},
+ )
tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the tokenized datasets."},
)
def __post_init__(self):
- if self.reserved_label_len >= self.cutoff_len:
- raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
-
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")
diff --git a/src/llamafactory/hparams/evaluation_args.py b/src/llamafactory/hparams/evaluation_args.py
index 5a05f6f6..a7f221ca 100644
--- a/src/llamafactory/hparams/evaluation_args.py
+++ b/src/llamafactory/hparams/evaluation_args.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 08af31e4..3867c0ec 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -1,5 +1,19 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import List, Literal, Optional
@dataclass
@@ -94,6 +108,18 @@ class LoraArguments:
default=False,
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
)
+ pissa_init: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to initialize a PiSSA adapter."},
+ )
+ pissa_iter: int = field(
+ default=16,
+ metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
+ )
+ pissa_convert: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
+ )
create_new_adapter: bool = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
@@ -319,20 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
return [item.strip() for item in arg.split(",")]
return arg
- self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
- self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
- self.lora_alpha = self.lora_alpha or self.lora_rank * 2
- self.lora_target = split_arg(self.lora_target)
- self.additional_target = split_arg(self.additional_target)
- self.galore_target = split_arg(self.galore_target)
+ self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+ self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+ self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+ self.lora_target: List[str] = split_arg(self.lora_target)
+ self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+ self.galore_target: List[str] = split_arg(self.galore_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+ self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
- self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
-
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("`reward_model` is necessary for PPO training.")
@@ -354,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+ if self.pissa_init and self.finetuning_type != "lora":
+ raise ValueError("`pissa_init` is only valid for LoRA training.")
+
+ if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
+ raise ValueError("Cannot use PiSSA for current training stage.")
+
if self.train_mm_proj_only and self.finetuning_type != "full":
raise ValueError("`train_mm_proj_only` is only valid for full training.")
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index 0ee17d1a..7ebb4eed 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 6352a420..087c8c38 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -1,5 +1,28 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
+
+from typing_extensions import Self
+
+
+if TYPE_CHECKING:
+ import torch
@dataclass
@@ -22,6 +45,10 @@ class ModelArguments:
)
},
)
+ adapter_folder: Optional[str] = field(
+ default=None,
+ metadata={"help": "The folder containing the adapter weights to load."},
+ )
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
@@ -50,6 +77,10 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
+ quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
+ default="bitsandbytes",
+ metadata={"help": "Quantization method to use for on-the-fly quantization."},
+ )
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
@@ -70,7 +101,7 @@ class ModelArguments:
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
- flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+ flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
@@ -127,13 +158,9 @@ class ModelArguments:
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
- default=8,
+ default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
- vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
- default="auto",
- metadata={"help": "Data type for model weights and activations in the vLLM engine."},
- )
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
@@ -142,6 +169,10 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
+ infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
+ default="auto",
+ metadata={"help": "Data type for model weights and activations at inference."},
+ )
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
@@ -192,9 +223,9 @@ class ModelArguments:
)
def __post_init__(self):
- self.compute_dtype = None
- self.device_map = None
- self.model_max_length = None
+ self.compute_dtype: Optional["torch.dtype"] = None
+ self.device_map: Optional[Union[str, Dict[str, Any]]] = None
+ self.model_max_length: Optional[int] = None
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -208,11 +239,18 @@ class ModelArguments:
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
- assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
- assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
-
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
+
+ @classmethod
+ def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
+ arg_dict = old_arg.to_dict()
+ arg_dict.update(**kwargs)
+ new_arg = cls(**arg_dict)
+ new_arg.compute_dtype = old_arg.compute_dtype
+ new_arg.device_map = old_arg.device_map
+ new_arg.model_max_length = old_arg.model_max_length
+ return new_arg
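The new `copyfrom` classmethod derives a tweaked copy of an existing `ModelArguments` while preserving the fields that are only populated after parsing (`compute_dtype`, `device_map`, `model_max_length`). A minimal usage sketch, assuming the public `llamafactory.hparams` import path and an illustrative model name:

```python
from llamafactory.hparams import ModelArguments  # assumed import path

base_args = ModelArguments(model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct")  # illustrative model
# Clone the arguments with overrides; the post-init fields are carried over unchanged.
infer_args = ModelArguments.copyfrom(base_args, use_cache=True, infer_dtype="bfloat16")
assert infer_args.compute_dtype == base_args.compute_dtype
```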
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index ff1fbf5d..8b2ea4c1 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import logging
import os
import sys
@@ -8,6 +25,7 @@ import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
+from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
@@ -65,13 +83,13 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
- if model_args.use_unsloth and is_deepspeed_zero3_enabled():
- raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
-
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
+ if finetuning_args.pissa_init:
+ raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
+
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
@@ -100,7 +118,7 @@ def _check_extra_dependencies(
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
- require_version("badam", "To fix: pip install badam")
+ require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
@@ -162,6 +180,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
+ if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+ raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
+
+ if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
+ raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
+
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
@@ -171,32 +195,31 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
- if finetuning_args.use_dora and model_args.use_unsloth:
- raise ValueError("Unsloth does not support DoRA.")
+ if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
+ raise ValueError("PiSSA is incompatible with DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
raise ValueError("This device does not support `pure_bf16`.")
- if training_args.fp16 or training_args.bf16:
- raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
+ if is_deepspeed_zero3_enabled():
+ raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if (
finetuning_args.use_galore
and finetuning_args.galore_layerwise
- and training_args.parallel_mode.value == "distributed"
+ and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
- if (
- finetuning_args.use_badam
- and finetuning_args.badam_mode == "layer"
- and training_args.parallel_mode.value == "distributed"
- ):
- raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+ if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
+ if finetuning_args.badam_mode == "ratio":
+ raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
+ elif not is_deepspeed_zero3_enabled():
+ raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
- if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
- raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
+ if finetuning_args.use_galore and training_args.deepspeed is not None:
+ raise ValueError("GaLore is incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -204,6 +227,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
+ if model_args.use_unsloth and is_deepspeed_zero3_enabled():
+ raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
+
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -233,7 +259,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
# Post-process training arguments
if (
- training_args.parallel_mode.value == "distributed"
+ training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
@@ -293,7 +319,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
training_args.local_rank,
training_args.device,
training_args.n_gpu,
- training_args.parallel_mode.value == "distributed",
+ training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
)
@@ -332,6 +358,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
+ model_args.model_max_length = data_args.cutoff_len
else:
model_args.device_map = "auto"
diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py
index de154db9..65e0b68f 100644
--- a/src/llamafactory/launcher.py
+++ b/src/llamafactory/launcher.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from llamafactory.train.tuner import run_exp
diff --git a/src/llamafactory/model/__init__.py b/src/llamafactory/model/__init__.py
index 9d23d59f..48cfe76c 100644
--- a/src/llamafactory/model/__init__.py
+++ b/src/llamafactory/model/__init__.py
@@ -1,9 +1,25 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
+from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params
__all__ = [
+ "QuantizationMethod",
"load_config",
"load_model",
"load_tokenizer",
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index f4e501a7..7caef9cc 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import re
from typing import TYPE_CHECKING
@@ -25,8 +39,12 @@ def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
+ if not is_trainable:
+ return
+
logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
@@ -47,8 +65,12 @@ def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
+ is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
+ if not is_trainable:
+ return
+
logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
@@ -132,7 +154,9 @@ def _setup_lora_tuning(
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
- logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+ if is_trainable:
+ logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
@@ -155,8 +179,16 @@ def _setup_lora_tuning(
else:
adapter_to_merge = model_args.adapter_name_or_path
+ init_kwargs = {
+ "subfolder": model_args.adapter_folder,
+ "offload_folder": model_args.offload_folder,
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "token": model_args.hf_hub_token,
+ }
+
for adapter in adapter_to_merge:
- model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
+ model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
@@ -166,12 +198,9 @@ def _setup_lora_tuning(
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
- model = PeftModel.from_pretrained(
- model,
- adapter_to_resume,
- is_trainable=is_trainable,
- offload_folder=model_args.offload_folder,
- )
+ model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
+
+ logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -209,16 +238,24 @@ def _setup_lora_tuning(
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
+ "use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
+ if finetuning_args.pissa_init:
+ if finetuning_args.pissa_iter == -1:
+ logger.info("Using PiSSA initialization.")
+ peft_kwargs["init_lora_weights"] = "pissa"
+ else:
+ logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
+ peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
- use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
@@ -227,9 +264,6 @@ def _setup_lora_tuning(
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
- if model_args.adapter_name_or_path is not None:
- logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
-
return model
@@ -247,29 +281,36 @@ def init_adapter(
Note that the trainable parameters must be cast to float32.
"""
- if (not is_trainable) and model_args.adapter_name_or_path is None:
- logger.info("Adapter is not found at evaluation, load the base model.")
- return model
+ if is_trainable and getattr(model, "quantization_method", None) is not None:
+ if finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantized models can only be used for the LoRA tuning.")
- if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
- raise ValueError("You can only use lora for quantized models.")
+ if finetuning_args.pissa_init:
+ raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
- if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
- logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
- cast_trainable_params_to_fp32 = False
+ # cast trainable parameters to float32 if:
+ # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
+ # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+ cast_trainable_params_to_fp32 = False
+ if not is_trainable:
+ pass
+ elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
+ logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+ elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+ logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else:
logger.info("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
- if is_trainable and finetuning_args.finetuning_type == "full":
- _setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
-
- if is_trainable and finetuning_args.finetuning_type == "freeze":
- _setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
-
- if finetuning_args.finetuning_type == "lora":
+ if finetuning_args.finetuning_type == "full":
+ _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+ elif finetuning_args.finetuning_type == "freeze":
+ _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+ elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
+ else:
+ raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
return model
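The PiSSA switch is implemented purely through PEFT's `init_lora_weights` option: `-1` requests an exact SVD, any other value requests fast SVD with that many iterations. A minimal sketch of the mapping, with illustrative rank and target modules:

```python
from peft import LoraConfig, TaskType, get_peft_model

def build_pissa_config(pissa_iter: int) -> LoraConfig:
    # pissa_iter == -1 -> exact SVD ("pissa"); otherwise fast SVD with k steps ("pissa_niter_k").
    init_lora_weights = "pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter)
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],  # illustrative targets
        init_lora_weights=init_lora_weights,
    )

# model = get_peft_model(model, build_pissa_config(pissa_iter=16))
```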
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 026a09be..43e65d52 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -1,10 +1,25 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
+import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, try_download_model_from_ms
+from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
@@ -33,6 +48,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
Note: including inplace operation of model_args.
"""
+ skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
@@ -162,17 +178,21 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
+ for param in model.parameters():
+ if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
+ param.data = param.data.to(model_args.compute_dtype)
+
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
if is_trainable:
- param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+ param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
- param_stats = "all params: {:d}".format(all_param)
+ param_stats = "all params: {:,}".format(all_param)
logger.info(param_stats)
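At inference time the loader now downcasts any remaining float32 parameters to the compute dtype after freezing the model. A condensed sketch of that path:

```python
import torch

def freeze_for_inference(model: torch.nn.Module, compute_dtype: torch.dtype) -> None:
    # Freeze the model, then cast leftover float32 parameters down to the compute dtype.
    model.requires_grad_(False)
    for param in model.parameters():
        if param.data.dtype == torch.float32 and compute_dtype != torch.float32:
            param.data = param.data.to(compute_dtype)
    model.eval()
```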
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index b52ddc86..4bed7e21 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -1,7 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING
+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+
from ...extras.logging import get_logger
-from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING:
@@ -13,21 +28,33 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+ config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+ if getattr(config, "model_type", None) == "gemma2" and is_trainable: # gemma2 adopts soft-cap attention
+ if model_args.flash_attn == "auto":
+ logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+ model_args.flash_attn = "disabled"
+ elif model_args.flash_attn != "disabled":
+ logger.warning(
+ "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
+ "Will proceed at your own risk.".format(model_args.flash_attn)
+ )
+
if model_args.flash_attn == "auto":
return
- elif model_args.flash_attn == "off":
+ elif model_args.flash_attn == "disabled":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
- if not is_sdpa_available():
+ if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
- if not is_flash_attn2_available():
+ if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.")
return
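The `flash_attn` option was renamed from `"off"` to `"disabled"` and now relies on the availability helpers from `transformers.utils`. A sketch of how the four choices resolve to an `attn_implementation` value (the `fa2` mapping to `"flash_attention_2"` is assumed from the transformers API rather than shown in this hunk):

```python
from typing import Optional

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

def resolve_attn_implementation(flash_attn: str) -> Optional[str]:
    if flash_attn == "auto":
        return None  # let transformers pick the implementation
    if flash_attn == "disabled":
        return "eager"
    if flash_attn == "sdpa":
        return "sdpa" if is_torch_sdpa_available() else None  # needs torch>=2.1.1
    if flash_attn == "fa2":
        return "flash_attention_2" if is_flash_attn_2_available() else None
    raise ValueError("Unknown flash_attn value: {}".format(flash_attn))
```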
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index e0657be8..f4f3d8a5 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import inspect
from functools import partial
from types import MethodType
@@ -60,15 +78,12 @@ def _fp32_forward_post_hook(
return output.to(torch.float32)
-def prepare_model_for_training(
- model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
-) -> None:
+def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
- Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
@@ -87,8 +102,8 @@ def prepare_model_for_training(
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
- if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
- logger.info("Upcasting lm_head outputs in float32.")
- output_layer = getattr(model, output_layer_name)
+ if model_args.upcast_lmhead_output:
+ output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
+ logger.info("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)
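`prepare_model_for_training` no longer assumes an attribute literally named `lm_head`; it asks the model for its output embeddings. A self-contained sketch of the resulting upcast logic (the hook mirrors the `_fp32_forward_post_hook` defined earlier in this file):

```python
import torch

def _fp32_post_hook(module: torch.nn.Module, args, output: torch.Tensor) -> torch.Tensor:
    # Cast the output projection's results to float32 for a numerically stable loss.
    return output.to(torch.float32)

def upcast_lmhead_output(model) -> None:
    # Fetch the output projection generically instead of getattr(model, "lm_head").
    output_layer = model.get_output_embeddings()
    if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
        output_layer.register_forward_hook(_fp32_post_hook)
```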
diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py
index 3d9278e3..3ff79828 100644
--- a/src/llamafactory/model/model_utils/embedding.py
+++ b/src/llamafactory/model/model_utils/embedding.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index c8dc52f5..af30bd50 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -1,3 +1,22 @@
+# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+#
+# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+# This code is also inspired by the original LongLoRA implementation.
+# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
from typing import TYPE_CHECKING, Optional, Tuple
@@ -96,7 +115,8 @@ def llama_attention_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
- )
+ ),
+ dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -181,11 +201,9 @@ def llama_flash_attention_2_forward(
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
- else:
- groupsz = q_len
attn_output: torch.Tensor = self._flash_attention_forward(
- query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+ query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
@@ -194,7 +212,8 @@ def llama_flash_attention_2_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
- )
+ ),
+ dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +312,8 @@ def llama_sdpa_attention_forward(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
- )
+ ),
+ dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -303,7 +323,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
- require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
+ require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
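All three attention backends receive the same fix: the shifted half of the heads must be concatenated back along the head dimension (`dim=2`); without the explicit `dim`, `torch.cat` defaults to `dim=0` and corrupts the batch axis. A self-contained sketch of the shift-back step:

```python
import torch

def shift_back(attn_output: torch.Tensor, num_heads: int, groupsz: int) -> torch.Tensor:
    # attn_output has shape (bsz, q_len, num_heads, head_dim); the second half of
    # the heads is rolled back along the sequence axis and re-joined on dim=2.
    return torch.cat(
        (
            attn_output[:, :, : num_heads // 2],
            attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
        ),
        dim=2,
    )
```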
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 4851bd29..a2812228 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger
diff --git a/src/llamafactory/model/model_utils/mod.py b/src/llamafactory/model/model_utils/mod.py
index 5708a1a8..ec73af00 100644
--- a/src/llamafactory/model/model_utils/mod.py
+++ b/src/llamafactory/model/model_utils/mod.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index e554e45a..5c7473aa 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -1,5 +1,20 @@
-from typing import TYPE_CHECKING
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Sequence
+
+import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@@ -10,6 +25,13 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+ require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+ from deepspeed.utils import set_z3_leaf_modules # type: ignore
+
+ set_z3_leaf_modules(model, leaf_modules)
+
+
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
@@ -17,33 +39,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
if not is_deepspeed_zero3_enabled():
return
- require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
- from deepspeed.utils import set_z3_leaf_modules # type: ignore
-
if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
- set_z3_leaf_modules(model, [DbrxFFN])
+ _set_z3_leaf_modules(model, [DbrxFFN])
if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
- set_z3_leaf_modules(model, [JambaSparseMoeBlock])
+ _set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
- set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
+ _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
- set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+ _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
- set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+ _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
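The per-architecture branches now go through a small `_set_z3_leaf_modules` wrapper that checks the DeepSpeed version once. A hypothetical usage sketch of the public entry point (the model name and import path are illustrative assumptions):

```python
from transformers import AutoModelForCausalLM

from llamafactory.model.model_utils.moe import add_z3_leaf_module  # assumed import path

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")  # illustrative MoE model
add_z3_leaf_module(model)  # no-op unless DeepSpeed ZeRO-3 is enabled
```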
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 02a54f07..317646e0 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and Optimum library.
+# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
+# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import random
from enum import Enum, unique
@@ -5,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
import torch
from datasets import load_dataset
-from transformers import BitsAndBytesConfig, GPTQConfig
+from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
@@ -39,10 +57,9 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq"
-def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
- Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
- TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
+ Prepares the tokenized dataset used by AutoGPTQ. Do not return tensors, since the output must be JSON-serializable.
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
@@ -51,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
data_path = model_args.export_quantization_dataset
data_files = None
- dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
- maxlen = model_args.export_quantization_maxlen
+ dataset = load_dataset(
+ path=data_path,
+ data_files=data_files,
+ split="train",
+ cache_dir=model_args.cache_dir,
+ token=model_args.hf_hub_token,
+ )
samples = []
+ maxlen = model_args.export_quantization_maxlen
for _ in range(model_args.export_quantization_nsamples):
+ n_try = 0
while True:
+ if n_try > 100:
+ raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
+
sample_idx = random.randint(0, len(dataset) - 1)
- sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
- if sample["input_ids"].size(1) >= maxlen:
+ sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+ n_try += 1
+ if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
- samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
+ attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
+ samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
return samples
@@ -76,14 +105,14 @@ def configure_quantization(
init_kwargs: Dict[str, Any],
) -> None:
r"""
- Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+ Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
if getattr(config, "quantization_config", None): # ptq
- if is_deepspeed_zero3_enabled():
- raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
+ if model_args.quantization_bit is not None:
+ logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
- if model_args.quantization_device_map != "auto":
- init_kwargs["device_map"] = {"": get_current_device()}
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
@@ -105,46 +134,72 @@ def configure_quantization(
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
- require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
+ if model_args.export_quantization_bit not in [8, 4, 3, 2]:
+ raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
+
+ require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
- raise ValueError("ChatGLM model is not supported.")
+ raise ValueError("ChatGLM model is not supported yet.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
- tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
- logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
+ logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
- elif model_args.quantization_bit is not None: # bnb
- if model_args.quantization_bit == 8:
- require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
- init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+ elif model_args.quantization_bit is not None: # on-the-fly
+ if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+ if model_args.quantization_bit == 8:
+ require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+ init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+ elif model_args.quantization_bit == 4:
+ require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+ init_kwargs["quantization_config"] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=model_args.compute_dtype,
+ bnb_4bit_use_double_quant=model_args.double_quantization,
+ bnb_4bit_quant_type=model_args.quantization_type,
+ bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
+ )
+ else:
+ raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
- elif model_args.quantization_bit == 4:
- require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
- init_kwargs["quantization_config"] = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_compute_dtype=model_args.compute_dtype,
- bnb_4bit_use_double_quant=model_args.double_quantization,
- bnb_4bit_quant_type=model_args.quantization_type,
- bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
- )
+ # Do not assign device map if:
+ # 1. deepspeed zero3 or fsdp (train)
+ # 2. auto quantization device map (inference)
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
+ if model_args.quantization_bit != 4:
+ raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
- if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
- if model_args.quantization_bit != 4:
- raise ValueError("Only 4-bit quantized model can use auto device map.")
+ require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+ else:
+ init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
- require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
- require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
- require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
- init_kwargs["torch_dtype"] = model_args.compute_dtype # fsdp+qlora requires same dtype
- else:
- init_kwargs["device_map"] = {"": get_current_device()}
+ logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
+ elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+ if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
+ raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
- logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+ require_version("hqq", "To fix: pip install hqq")
+ init_kwargs["quantization_config"] = HqqConfig(
+ nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
+ ) # use ATEN kernel (axis=0) for performance
+ logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
+ elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+ if model_args.quantization_bit != 8:
+ raise ValueError("EETQ only accepts 8-bit quantization.")
+
+ if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
+ raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
+
+ require_version("eetq", "To fix: pip install eetq")
+ init_kwargs["quantization_config"] = EetqConfig()
+ logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index 93ab8929..4373ee19 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -1,3 +1,21 @@
+# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# This code is inspired by the LMSYS's FastChat library.
+# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
from typing import TYPE_CHECKING
@@ -21,8 +39,8 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.warning("Current model does not support RoPE scaling.")
return
- if is_trainable:
- if model_args.rope_scaling == "dynamic":
+ if model_args.model_max_length is not None:
+ if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py
index 8a16409d..9cfaec61 100644
--- a/src/llamafactory/model/model_utils/unsloth.py
+++ b/src/llamafactory/model/model_utils/unsloth.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
diff --git a/src/llamafactory/model/model_utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py
index 64333688..9ab3d45a 100644
--- a/src/llamafactory/model/model_utils/valuehead.py
+++ b/src/llamafactory/model/model_utils/valuehead.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
import torch
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index c8260b7f..700bf470 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Tuple
import torch
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 47591de6..f1831ced 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
@@ -46,13 +60,16 @@ def patch_config(
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
- model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
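+        # if a concrete infer_dtype (e.g. float16) is requested and we are not training,
+        # honor it; otherwise infer the compute dtype from the checkpoint's torch_dtype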
+ if model_args.infer_dtype != "auto" and not is_trainable:
+ model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
+ else:
+ model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
- configure_attn_implementation(config, model_args)
+ configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
@@ -74,14 +91,17 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
- if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp
+ # cast data type of the model if:
+ # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
+ # 2. quantization_bit is not None (qlora)
+ if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map
- if init_kwargs["device_map"] == "auto":
+ if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
if finetune_args.stage == "sft" and data_args.efficient_packing:
@@ -137,6 +157,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
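+    # delegate output-embedding lookup to the wrapped model, mirroring get_input_embeddings above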
+ def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+ if isinstance(self.pretrained_model, PreTrainedModel):
+ return self.pretrained_model.get_output_embeddings()
+
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
@@ -145,4 +169,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+ setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
diff --git a/src/llamafactory/extras/callbacks.py b/src/llamafactory/train/callbacks.py
similarity index 56%
rename from src/llamafactory/extras/callbacks.py
rename to src/llamafactory/train/callbacks.py
index 441ebbfd..4d024278 100644
--- a/src/llamafactory/extras/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import logging
import os
@@ -8,22 +22,78 @@ from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
+import torch
import transformers
-from transformers import TrainerCallback
+from peft import PeftModel
+from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
+from transformers.utils import (
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ is_safetensors_available,
+)
-from .constants import TRAINER_LOG
-from .logging import LoggerHandler, get_logger
-from .misc import fix_valuehead_checkpoint
+from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.logging import LoggerHandler, get_logger
+if is_safetensors_available():
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
+ from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
+def fix_valuehead_checkpoint(
+ model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
+) -> None:
+ r"""
+ The model is already unwrapped.
+
+ There are three cases:
+ 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+ 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+ 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+ We assume `stage3_gather_16bit_weights_on_model_save=true`.
+ """
+ if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+ return
+
+ if safe_serialization:
+ path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+ with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+ state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+ else:
+ path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+ state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
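+    # split the flat state dict into decoder weights and value-head weights;
+    # the "pretrained_model." prefix only appears when saving under deepspeed zero3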
+ decoder_state_dict = {}
+ v_head_state_dict = {}
+ for name, param in state_dict.items():
+ if name.startswith("v_head."):
+ v_head_state_dict[name] = param
+ else:
+ decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+ os.remove(path_to_checkpoint)
+ model.pretrained_model.save_pretrained(
+ output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
+ )
+
+ if safe_serialization:
+ save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+ else:
+ torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+ logger.info("Value head model saved at: {}".format(output_dir))
+
+
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
@@ -37,8 +107,70 @@ class FixValueHeadModelCallback(TrainerCallback):
)
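+# the (image) processor is saved via this callback instead of per-trainer `_save` overrides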
+class SaveProcessorCallback(TrainerCallback):
+ def __init__(self, processor: "ProcessorMixin") -> None:
+ r"""
+ Initializes a callback for saving the processor.
+ """
+ self.processor = processor
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if args.should_save:
+ getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
+
+
+class PissaConvertCallback(TrainerCallback):
+ r"""
+    A callback that converts the PiSSA adapter to a normal LoRA adapter.
+ """
+
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the beginning of training.
+ """
+ if args.should_save:
+ model = kwargs.pop("model")
+ pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+            logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
+ if isinstance(model, PeftModel):
+ init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+ setattr(model.peft_config["default"], "init_lora_weights", True)
+ model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+ r"""
+ Event called at the end of training.
+ """
+ if args.should_save:
+ model = kwargs.pop("model")
+ pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
+ pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
+ pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
+ logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
+ # 1. save a pissa backup with init_lora_weights: True
+ # 2. save a converted lora with init_lora_weights: pissa
+ # 3. load the pissa backup with init_lora_weights: True
+ # 4. delete the initial adapter and change init_lora_weights to pissa
+ if isinstance(model, PeftModel):
+ init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
+ setattr(model.peft_config["default"], "init_lora_weights", True)
+ model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+ model.save_pretrained(
+ pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
+ )
+ model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
+ model.set_adapter("default")
+ model.delete_adapter("pissa_init")
+ setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
+
+
class LogCallback(TrainerCallback):
- def __init__(self, output_dir: str) -> None:
+ def __init__(self) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
@@ -56,7 +188,7 @@ class LogCallback(TrainerCallback):
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
- self.logger_handler = LoggerHandler(output_dir)
+ self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
diff --git a/src/llamafactory/train/dpo/__init__.py b/src/llamafactory/train/dpo/__init__.py
index 43fe9420..9ce0d089 100644
--- a/src/llamafactory/train/dpo/__init__.py
+++ b/src/llamafactory/train/dpo/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_dpo
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index d860b29a..e45467d6 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
@@ -10,7 +28,8 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -35,7 +54,6 @@ class CustomDPOTrainer(DPOTrainer):
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
- self.processor = processor
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
@@ -61,6 +79,8 @@ class CustomDPOTrainer(DPOTrainer):
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
+
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
@@ -71,10 +91,17 @@ class CustomDPOTrainer(DPOTrainer):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if finetuning_args.pissa_convert:
+ self.callback_handler.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -87,12 +114,6 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
@@ -176,7 +197,7 @@ class CustomDPOTrainer(DPOTrainer):
if self.ref_model is None:
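+            # with no explicit ref model, disable the LoRA adapters on the policy model
+            # so that the frozen base weights act as the reference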
ref_model = model
- ref_context = get_ref_context(self.accelerator, model)
+ ref_context = self.accelerator.unwrap_model(model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 992985b0..431b5285 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -1,4 +1,19 @@
-# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
diff --git a/src/llamafactory/train/kto/__init__.py b/src/llamafactory/train/kto/__init__.py
index 34c7905a..a1900368 100644
--- a/src/llamafactory/train/kto/__init__.py
+++ b/src/llamafactory/train/kto/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_kto
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 22a84e4a..460311e4 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
@@ -9,7 +27,8 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
if TYPE_CHECKING:
@@ -35,7 +54,6 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
- self.processor = processor
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
@@ -60,6 +78,8 @@ class CustomKTOTrainer(KTOTrainer):
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
+
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
@@ -70,10 +90,14 @@ class CustomKTOTrainer(KTOTrainer):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -92,12 +116,6 @@ class CustomKTOTrainer(KTOTrainer):
"""
return Trainer._get_train_sampler(self)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -143,7 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
"""
if self.ref_model is None:
ref_model = model
- ref_context = get_ref_context(self.accelerator, model)
+ ref_context = self.accelerator.unwrap_model(model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index c79b160b..8182a184 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, List, Optional
from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset
diff --git a/src/llamafactory/train/ppo/__init__.py b/src/llamafactory/train/ppo/__init__.py
index d17336d5..161f6f5d 100644
--- a/src/llamafactory/train/ppo/__init__.py
+++ b/src/llamafactory/train/ppo/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_ppo
diff --git a/src/llamafactory/train/ppo/ppo_utils.py b/src/llamafactory/train/ppo/ppo_utils.py
index fec3fc1e..05c40946 100644
--- a/src/llamafactory/train/ppo/ppo_utils.py
+++ b/src/llamafactory/train/ppo/ppo_utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 2e1288e4..57f0b848 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -1,6 +1,24 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
import os
import sys
+import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -9,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
+from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
@@ -16,9 +35,9 @@ from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
-from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -81,10 +100,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
)
# Add deepspeed config
- ppo_config.accelerator_kwargs["kwargs_handlers"] = [
- DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
- ]
if training_args.deepspeed_plugin is not None:
+ ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+ DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+ ]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
# Create optimizer and scheduler
@@ -113,7 +132,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.current_device = get_current_device() # patch for deepspeed training
- self.processor = processor
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
@@ -125,8 +143,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl()
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
- self.log_callback, self.save_callback = callbacks[0], callbacks[1]
- assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
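+        # route trainer events through a standard CallbackHandler instead of calling
+        # the log/save callbacks by hand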
+ self.callback_handler = CallbackHandler(
+ [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
+ )
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
@@ -134,8 +153,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
- device_type = unwrapped_model.pretrained_model.device.type
- self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+ self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
+ warnings.simplefilter("ignore") # remove gc warnings on ref model
if finetuning_args.reward_model_type == "full":
if self.is_deepspeed_enabled:
@@ -147,10 +166,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ self.add_callback(FixValueHeadModelCallback)
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
@@ -184,23 +209,23 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.is_world_process_zero():
logger.info("***** Running training *****")
- logger.info(" Num examples = {}".format(num_examples))
- logger.info(" Num Epochs = {}".format(num_train_epochs))
- logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
+ logger.info(" Num examples = {:,}".format(num_examples))
+ logger.info(" Num Epochs = {:,}".format(num_train_epochs))
+ logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
logger.info(
- " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
+ " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
)
- logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
- logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
- logger.info(" Total training steps = {}".format(max_steps))
- logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
+ logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
+ logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
+ logger.info(" Total training steps = {:,}".format(max_steps))
+ logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
reward_meter = AverageMeter()
- self.log_callback.on_train_begin(self.args, self.state, self.control)
+ self.callback_handler.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
try:
@@ -238,7 +263,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.warning("Failed to save stats due to unknown errors.")
self.state.global_step += 1
- self.log_callback.on_step_end(self.args, self.state, self.control)
+ self.callback_handler.on_step_end(self.args, self.state, self.control)
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
logs = dict(
@@ -250,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
tqdm.write(str(logs))
logs["step"] = step
self.state.log_history.append(logs)
- self.log_callback.on_log(self.args, self.state, self.control)
+ self.callback_handler.on_log(self.args, self.state, self.control, logs)
loss_meter.reset()
reward_meter.reset()
@@ -258,17 +283,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
)
- self.save_callback.on_save(
- self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
- )
+ self.callback_handler.on_save(self.args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
- self.log_callback.on_train_end(self.args, self.state, self.control)
- self.save_callback.on_train_end(
- self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
- )
+ self.callback_handler.on_train_end(self.args, self.state, self.control)
def create_optimizer(
self,
@@ -486,7 +506,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
elif self.args.should_save:
self._save(output_dir)
-
- if self.processor is not None and self.args.should_save:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 111704c6..651296f3 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -1,14 +1,28 @@
-# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's TRL library.
+# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorWithPadding
from ...data import get_dataset
-from ...extras.callbacks import FixValueHeadModelCallback
-from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
+from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@@ -60,6 +74,7 @@ def run_ppo(
ppo_trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])
diff --git a/src/llamafactory/train/pt/__init__.py b/src/llamafactory/train/pt/__init__.py
index bdf397f6..d80e6f22 100644
--- a/src/llamafactory/train/pt/__init__.py
+++ b/src/llamafactory/train/pt/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_pt
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 1d96e82f..e8f180a6 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -1,9 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from types import MethodType
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from ...extras.logging import get_logger
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
@@ -27,11 +42,18 @@ class CustomTrainer(Trainer):
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
- self.processor = processor
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -43,9 +65,3 @@ class CustomTrainer(Trainer):
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
-
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py
index 8a635567..b84a0e7d 100644
--- a/src/llamafactory/train/pt/workflow.py
+++ b/src/llamafactory/train/pt/workflow.py
@@ -1,4 +1,19 @@
-# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import math
from typing import TYPE_CHECKING, List, Optional
diff --git a/src/llamafactory/train/rm/__init__.py b/src/llamafactory/train/rm/__init__.py
index dedac35f..48278315 100644
--- a/src/llamafactory/train/rm/__init__.py
+++ b/src/llamafactory/train/rm/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_rm
diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py
index 99dc6ab8..fb880b1c 100644
--- a/src/llamafactory/train/rm/metric.py
+++ b/src/llamafactory/train/rm/metric.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Dict, Sequence, Tuple, Union
import numpy as np
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index bfb344dc..accc877d 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -1,3 +1,42 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by CarperAI's trlx library.
+# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2022 CarperAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
import json
import os
from types import MethodType
@@ -7,6 +46,7 @@ import torch
from transformers import Trainer
from ...extras.logging import get_logger
+from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
@@ -30,12 +70,20 @@ class PairwiseTrainer(Trainer):
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
- self.processor = processor
self.can_return_loss = True # override property to return eval_loss
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
+ self.add_callback(FixValueHeadModelCallback)
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -48,12 +96,6 @@ class PairwiseTrainer(Trainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -63,7 +105,7 @@ class PairwiseTrainer(Trainer):
Subclass and override to inject custom behavior.
Note that the first element will be removed from the output tuple.
- See: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/trainer.py#L3777
+ See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
"""
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
@@ -79,7 +121,6 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = [], []
# Compute pairwise loss. Only backprop on the different tokens before padding
- # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
loss = 0
for i in range(batch_size):
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
@@ -125,4 +166,5 @@ class PairwiseTrainer(Trainer):
res: List[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
+
writer.write("\n".join(res))
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index 2e9e194b..e0b32b77 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -1,12 +1,48 @@
-# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+# Copyright 2024 the LlamaFactory team.
+#
+# This code is inspired by CarperAI's trlx library.
+# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# MIT License
+#
+# Copyright (c) 2022 CarperAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
-from ...extras.callbacks import FixValueHeadModelCallback
-from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
+from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_modelcard_and_push
from .metric import compute_accuracy
from .trainer import PairwiseTrainer
@@ -40,7 +76,7 @@ def run_rm(
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
- callbacks=callbacks + [FixValueHeadModelCallback()],
+ callbacks=callbacks,
compute_metrics=compute_accuracy,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
@@ -52,6 +88,7 @@ def run_rm(
trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
+
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
diff --git a/src/llamafactory/train/sft/__init__.py b/src/llamafactory/train/sft/__init__.py
index f2f84e78..475dfe5f 100644
--- a/src/llamafactory/train/sft/__init__.py
+++ b/src/llamafactory/train/sft/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .workflow import run_sft
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index b135fcfb..c69608c0 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -1,14 +1,35 @@
+# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library and THUDM's ChatGLM implementation.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict
import numpy as np
+import torch
+from transformers import EvalPrediction
+from transformers.utils import is_jieba_available, is_nltk_available
from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
+from ...extras.packages import is_rouge_available
if TYPE_CHECKING:
- from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers import PreTrainedTokenizer
if is_jieba_available():
@@ -23,6 +44,22 @@ if is_rouge_available():
from rouge_chinese import Rouge
+def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
+ preds, labels = eval_preds.predictions, eval_preds.label_ids
+ accuracies = []
+ for i in range(len(preds)):
+ pred, label = preds[i, :-1], labels[i, 1:]
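+        # align the token predicted at position t with the gold token at position t + 1,
+        # ignoring positions whose label is IGNORE_INDEX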
+ label_mask = label != IGNORE_INDEX
+ accuracies.append(np.mean(pred[label_mask] == label[label_mask]))
+
+ return {"accuracy": float(np.mean(accuracies))}
+
+
+def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+ logits = logits[0] if isinstance(logits, (list, tuple)) else logits
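+    # collapse logits to argmax token ids before they are gathered for metrics,
+    # so full vocabulary logits never need to be stored during evaluation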
+ return torch.argmax(logits, dim=-1)
+
+
@dataclass
class ComputeMetrics:
r"""
@@ -31,11 +68,11 @@ class ComputeMetrics:
tokenizer: "PreTrainedTokenizer"
- def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+ def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
r"""
Uses the model predictions to compute metrics.
"""
- preds, labels = eval_preds
+ preds, labels = eval_preds.predictions, eval_preds.label_ids
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index c063b214..954bb69f 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from types import MethodType
@@ -9,10 +26,12 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
+from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
+ from torch.utils.data import Dataset
from transformers import ProcessorMixin
from transformers.trainer import PredictionOutput
@@ -32,11 +51,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
- self.processor = processor
- if finetuning_args.use_badam:
- from badam import clip_grad_norm_for_sparse_tensor
- self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
+ if processor is not None:
+ self.add_callback(SaveProcessorCallback(processor))
+
+ if finetuning_args.pissa_convert:
+ self.add_callback(PissaConvertCallback)
+
+ if finetuning_args.use_badam:
+ from badam import BAdamCallback, clip_grad_norm_old_version
+
+ self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+ self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -49,12 +75,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
- super()._save(output_dir, state_dict)
- if self.processor is not None:
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- getattr(self.processor, "image_processor").save_pretrained(output_dir)
-
def prediction_step(
self,
model: "torch.nn.Module",
@@ -94,7 +114,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
- def save_predictions(self, predict_results: "PredictionOutput") -> None:
+ def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
@@ -115,18 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
- if len(pad_len):
- preds[i] = np.concatenate(
- (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
- ) # move pad token to last
+ if len(pad_len): # move pad token to last
+ preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
- decoded_labels = self.tokenizer.batch_decode(
- labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
- )
- decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+ decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
+ decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+ decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
- for label, pred in zip(decoded_labels, decoded_preds):
- res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
+ for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
+ res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
+
writer.write("\n".join(res))
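The updated `save_predictions` above writes one JSON object per line and now includes the decoded prompt alongside the label and the prediction. A minimal sketch of a single output record (field values are made up for illustration; the real texts come from the `tokenizer.batch_decode` calls above):

```python
import json

# Illustrative record only; in save_predictions the three fields are the decoded
# input ids, label ids, and prediction ids of one example.
record = {"prompt": "Human: How are you?\nAssistant:", "label": "I am fine!", "predict": "I am fine."}
print(json.dumps(record, ensure_ascii=False))
```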
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index f1e000bd..c12a70aa 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -1,4 +1,19 @@
-# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
@@ -10,7 +25,7 @@ from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
-from .metric import ComputeMetrics
+from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
@@ -56,7 +71,8 @@ def run_sft(
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
- compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+ compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
+ preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
)
@@ -75,7 +91,7 @@ def run_sft(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
# Evaluation
if training_args.do_eval:
@@ -92,7 +108,7 @@ def run_sft(
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
- trainer.save_predictions(predict_results)
+ trainer.save_predictions(dataset, predict_results)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
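`preprocess_logits_for_metrics` is the standard Transformers hook for shrinking logits before they are accumulated for `compute_metrics`. The body of `eval_logit_processor` lives in `metric.py` and is not shown in this hunk; the sketch below assumes the usual pattern (greedy argmax down to token ids) and is illustrative rather than the project's exact code:

```python
import torch

def eval_logit_processor_sketch(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    # Some models return a tuple (logits, ...); keep only the logits tensor.
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    # Reduce [batch, seq_len, vocab] logits to predicted token ids so that
    # compute_accuracy receives small integer tensors instead of full logits.
    return torch.argmax(logits, dim=-1)
```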
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 0ddcdb11..4b581691 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -1,8 +1,27 @@
-from contextlib import contextmanager
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
+# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
+# and the original BAdam's implementation: https://github.com/Ledzy/BAdam
+# and the HuggingFace's TRL library: https://github.com/huggingface/trl
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
+from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
@@ -19,7 +38,6 @@ if is_galore_available():
if TYPE_CHECKING:
- from accelerate import Accelerator
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
@@ -83,15 +101,12 @@ def create_ref_model(
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
- ref_model_args_dict = model_args.to_dict()
- ref_model_args_dict.update(
- dict(
- model_name_or_path=finetuning_args.ref_model,
- adapter_name_or_path=finetuning_args.ref_model_adapters,
- quantization_bit=finetuning_args.ref_model_quantization_bit,
- )
+ ref_model_args = ModelArguments.copyfrom(
+ model_args,
+ model_name_or_path=finetuning_args.ref_model,
+ adapter_name_or_path=finetuning_args.ref_model_adapters,
+ quantization_bit=finetuning_args.ref_model_quantization_bit,
)
- ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
@@ -102,9 +117,11 @@ def create_ref_model(
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
- tokenizer = load_tokenizer(model_args)["tokenizer"]
+ ref_model_args = ModelArguments.copyfrom(model_args)
+ ref_finetuning_args = FinetuningArguments()
+ tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
- tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+ tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from the model itself.")
@@ -139,15 +156,12 @@ def create_reward_model(
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
- reward_model_args_dict = model_args.to_dict()
- reward_model_args_dict.update(
- dict(
- model_name_or_path=finetuning_args.reward_model,
- adapter_name_or_path=finetuning_args.reward_model_adapters,
- quantization_bit=finetuning_args.reward_model_quantization_bit,
- )
+ reward_model_args = ModelArguments.copyfrom(
+ model_args,
+ model_name_or_path=finetuning_args.reward_model,
+ adapter_name_or_path=finetuning_args.reward_model_adapters,
+ quantization_bit=finetuning_args.reward_model_quantization_bit,
)
- reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
reward_model = load_model(
@@ -158,17 +172,6 @@ def create_reward_model(
return reward_model
-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
- r"""
- Gets adapter context for the reference model.
- """
- with accelerator.unwrap_model(model).disable_adapter():
- model.eval()
- yield
- model.train()
-
-
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
@@ -184,7 +187,7 @@ def _create_galore_optimizer(
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
- galore_targets = find_all_linear_modules(model)
+ galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
galore_targets = finetuning_args.galore_target
@@ -334,6 +337,7 @@ def _create_badam_optimizer(
start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose,
+ ds_zero3_enabled=is_deepspeed_zero3_enabled(),
)
logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
@@ -355,7 +359,7 @@ def _create_badam_optimizer(
**optim_kwargs,
)
logger.info(
- f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+ f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)
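The two hunks above swap the `to_dict()`-and-update construction for a `ModelArguments.copyfrom(...)` classmethod defined outside this hunk. A rough sketch of the dataclass copy-with-overrides idiom it stands for, with simplified field names and no claim to match the actual signature:

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ArgsSketch:
    model_name_or_path: str = "base-model"
    quantization_bit: Optional[int] = None

    @classmethod
    def copyfrom(cls, old: "ArgsSketch", **overrides) -> "ArgsSketch":
        # Copy every existing field value, then apply the keyword overrides.
        values = {f.name: getattr(old, f.name) for f in fields(old)}
        values.update(overrides)
        return cls(**values)


ref_args = ArgsSketch.copyfrom(ArgsSketch(), model_name_or_path="ref-model", quantization_bit=4)
```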
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index eed875e9..dc982e07 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -1,13 +1,30 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
-from ..extras.callbacks import LogCallback
+from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
+from .callbacks import LogCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -24,8 +41,8 @@ logger = get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
+ callbacks.append(LogCallback())
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
- callbacks.append(LogCallback(training_args.output_dir))
if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
@@ -84,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
safe_serialization=(not model_args.export_legacy_format),
)
+ if finetuning_args.stage == "rm":
+ if model_args.adapter_name_or_path is not None:
+ vhead_path = model_args.adapter_name_or_path[-1]
+ else:
+ vhead_path = model_args.model_name_or_path
+
+ if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
+ shutil.copy(
+ os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
+ os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
+ )
+ logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+ elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
+ shutil.copy(
+ os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
+ os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
+ )
+ logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index c82710d3..8abef920 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
@@ -9,7 +23,7 @@ from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
-from .common import get_save_dir
+from .common import QUANTIZATION_BITS, get_save_dir
from .locales import ALERTS
@@ -62,17 +76,24 @@ class WebChatModel(ChatModel):
yield error
return
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_path,
finetuning_type=finetuning_type,
- quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
+ infer_dtype=get("infer.infer_dtype"),
)
if checkpoint_path:
@@ -126,16 +147,15 @@ class WebChatModel(ChatModel):
):
response += new_text
if tools:
- result = self.engine.template.format_tools.extract(response)
+ result = self.engine.template.extract_tool(response)
else:
result = response
- if isinstance(result, tuple):
- name, arguments = result
- arguments = json.loads(arguments)
- tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
- output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
- bot_text = "```json\n" + tool_call + "\n```"
+ if isinstance(result, list):
+ tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
+ tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
+ output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+ bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = result
diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py
index 37b38df0..bced18f0 100644
--- a/src/llamafactory/webui/common.py
+++ b/src/llamafactory/webui/common.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from collections import defaultdict
@@ -33,13 +47,19 @@ DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
+QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
+GPTQ_BITS = ["8", "4", "3", "2"]
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
"""
- paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
+ if os.path.sep in paths[-1]:
+        logger.warning("Found complex path, some features may not be available.")
+ return paths[-1]
+
+ paths = (path.replace(" ", "").strip() for path in paths)
return os.path.join(DEFAULT_SAVE_DIR, *paths)
diff --git a/src/llamafactory/webui/components/__init__.py b/src/llamafactory/webui/components/__init__.py
index 5c1e21b8..715fb6e4 100644
--- a/src/llamafactory/webui/components/__init__.py
+++ b/src/llamafactory/webui/components/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py
index f83694b1..ad74114b 100644
--- a/src/llamafactory/webui/components/chatbot.py
+++ b/src/llamafactory/webui/components/chatbot.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Tuple
from ...data import Role
diff --git a/src/llamafactory/webui/components/data.py b/src/llamafactory/webui/components/data.py
index 232b973d..88e500cf 100644
--- a/src/llamafactory/webui/components/data.py
+++ b/src/llamafactory/webui/components/data.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py
index 0a7a0f44..b522913e 100644
--- a/src/llamafactory/webui/components/eval.py
+++ b/src/llamafactory/webui/components/eval.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py
index 7e1493c8..0a938f02 100644
--- a/src/llamafactory/webui/components/export.py
+++ b/src/llamafactory/webui/components/export.py
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Generator, List, Union
from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
-from ..common import get_save_dir
+from ..common import GPTQ_BITS, get_save_dir
from ..locales import ALERTS
@@ -18,7 +32,11 @@ if TYPE_CHECKING:
from ..engine import Engine
-GPTQ_BITS = ["8", "4", "3", "2"]
+def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
+ if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
+ return gr.Dropdown(value="none", interactive=False)
+ else:
+ return gr.Dropdown(interactive=True)
def save_model(
@@ -96,6 +114,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_dir = gr.Textbox()
export_hub_model_id = gr.Textbox()
+ checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
+ checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
+
export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py
index 970f4629..a0064479 100644
--- a/src/llamafactory/webui/components/infer.py
+++ b/src/llamafactory/webui/components/infer.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
@@ -18,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
- infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+ with gr.Row():
+ infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+ infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
+
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
- input_elems.update({infer_backend})
- elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+ input_elems.update({infer_backend, infer_dtype})
+ elem_dict.update(
+ dict(
+ infer_backend=infer_backend,
+ infer_dtype=infer_dtype,
+ load_btn=load_btn,
+ unload_btn=unload_btn,
+ info_box=info_box,
+ )
+ )
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index fd0ead3d..9df3f062 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config
-from ..utils import can_quantize
+from ..utils import can_quantize, can_quantize_to
if is_gradio_available():
@@ -29,17 +43,23 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
- quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
- template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
- rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
- booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
+ quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1)
+ quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
+ template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
+ rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
+ booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
visual_inputs = gr.Checkbox(scale=1)
- model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
+ model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+ list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+ )
model_name.input(save_config, inputs=[lang, model_name], queue=False)
model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
- finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
+ finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
+ list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
+ )
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
+ quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
return dict(
lang=lang,
@@ -49,6 +69,7 @@ def create_top() -> Dict[str, "Component"]:
checkpoint_path=checkpoint_path,
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
+ quantization_method=quantization_method,
template=template,
rope_scaling=rope_scaling,
booster=booster,
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index dccc8500..4636050b 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
@@ -40,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
num_train_epochs = gr.Textbox(value="3.0")
max_grad_norm = gr.Textbox(value="1.0")
max_samples = gr.Textbox(value="100000")
- compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+ compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
elem_dict.update(
@@ -152,10 +166,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox()
with gr.Row():
- with gr.Column(scale=1):
- use_rslora = gr.Checkbox()
- use_dora = gr.Checkbox()
-
+ use_rslora = gr.Checkbox()
+ use_dora = gr.Checkbox()
+ use_pissa = gr.Checkbox()
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
@@ -168,6 +181,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter,
use_rslora,
use_dora,
+ use_pissa,
lora_target,
additional_target,
}
@@ -182,6 +196,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter=create_new_adapter,
use_rslora=use_rslora,
use_dora=use_dora,
+ use_pissa=use_pissa,
lora_target=lora_target,
additional_target=additional_target,
)
@@ -279,7 +294,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1):
loss_viewer = gr.Plot()
- input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload})
+ input_elems.update({output_dir, config_path, ds_stage, ds_offload})
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
diff --git a/src/llamafactory/webui/css.py b/src/llamafactory/webui/css.py
index 36e3d4c2..53982119 100644
--- a/src/llamafactory/webui/css.py
+++ b/src/llamafactory/webui/css.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
CSS = r"""
.duplicate-button {
margin: auto !important;
diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py
index eb6142d3..04893215 100644
--- a/src/llamafactory/webui/engine.py
+++ b/src/llamafactory/webui/engine.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Any, Dict
from .chatter import WebChatModel
diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py
index bae3ba76..d25f4d38 100644
--- a/src/llamafactory/webui/interface.py
+++ b/src/llamafactory/webui/interface.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from ..extras.packages import is_gradio_available
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 05cf3bed..852b1b3c 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
LOCALES = {
"lang": {
"en": {
@@ -71,15 +85,29 @@ LOCALES = {
"quantization_bit": {
"en": {
"label": "Quantization bit",
- "info": "Enable 4/8-bit model quantization (QLoRA).",
+ "info": "Enable quantization (QLoRA).",
},
"ru": {
"label": "Уровень квантования",
- "info": "Включить 4/8-битное квантование модели (QLoRA).",
+ "info": "Включить квантование (QLoRA).",
},
"zh": {
"label": "量化等级",
- "info": "启用 4/8 比特模型量化(QLoRA)。",
+ "info": "启用量化(QLoRA)。",
+ },
+ },
+ "quantization_method": {
+ "en": {
+ "label": "Quantization method",
+ "info": "Quantization algorithm to use.",
+ },
+ "ru": {
+ "label": "Метод квантования",
+ "info": "Алгоритм квантования, который следует использовать.",
+ },
+ "zh": {
+ "label": "量化方法",
+ "info": "使用的量化算法。",
},
},
"template": {
@@ -732,6 +760,20 @@ LOCALES = {
"info": "使用权重分解的 LoRA。",
},
},
+ "use_pissa": {
+ "en": {
+ "label": "Use PiSSA",
+ "info": "Use PiSSA method.",
+ },
+ "ru": {
+ "label": "используйте PiSSA",
+ "info": "Используйте метод PiSSA.",
+ },
+ "zh": {
+ "label": "使用 PiSSA",
+ "info": "使用 PiSSA 方法。",
+ },
+ },
"lora_target": {
"en": {
"label": "LoRA modules (optional)",
@@ -1192,6 +1234,17 @@ LOCALES = {
"label": "推理引擎",
},
},
+ "infer_dtype": {
+ "en": {
+ "label": "Inference data type",
+ },
+ "ru": {
+ "label": "Тип данных для вывода",
+ },
+ "zh": {
+ "label": "推理数据类型",
+ },
+ },
"load_btn": {
"en": {
"value": "Load model",
diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py
index 326fdb8d..ebe9f1b9 100644
--- a/src/llamafactory/webui/manager.py
+++ b/src/llamafactory/webui/manager.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
@@ -57,6 +71,7 @@ class Manager:
self._id_to_elem["top.finetuning_type"],
self._id_to_elem["top.checkpoint_path"],
self._id_to_elem["top.quantization_bit"],
+ self._id_to_elem["top.quantization_method"],
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index 852805da..ffec54e2 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
@@ -8,9 +22,9 @@ from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES
-from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
+from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
if is_gradio_available():
@@ -38,7 +52,7 @@ class Runner:
def set_abort(self) -> None:
self.aborted = True
if self.trainer is not None:
- abort_leaf_process(self.trainer.pid)
+ abort_process(self.trainer.pid)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
@@ -90,6 +104,11 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
args = dict(
stage=TRAINING_STAGES[get("train.training_stage")],
do_train=True,
@@ -97,7 +116,8 @@ class Runner:
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=finetuning_type,
- quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -160,6 +180,8 @@ class Runner:
args["create_new_adapter"] = get("train.create_new_adapter")
args["use_rslora"] = get("train.use_rslora")
args["use_dora"] = get("train.use_dora")
+ args["pissa_init"] = get("train.use_pissa")
+ args["pissa_convert"] = get("train.use_pissa")
args["lora_target"] = get("train.lora_target") or "all"
args["additional_target"] = get("train.additional_target") or None
@@ -219,13 +241,19 @@ class Runner:
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
+ if get("top.quantization_bit") in QUANTIZATION_BITS:
+ quantization_bit = int(get("top.quantization_bit"))
+ else:
+ quantization_bit = None
+
args = dict(
stage="sft",
model_name_or_path=get("top.model_path"),
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=finetuning_type,
- quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+ quantization_bit=quantization_bit,
+ quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -283,6 +311,7 @@ class Runner:
env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"
+ env["LLAMABOARD_WORKDIR"] = args["output_dir"]
if args.get("deepspeed", None) is not None:
env["FORCE_TORCHRUN"] = "1"
@@ -291,7 +320,7 @@ class Runner:
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
config_dict = {}
- skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
+ skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
elem_id = self.manager.get_id_by_elem(elem)
if elem_id not in skip_ids:
diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py
index e39f2aa4..6e5fdbe4 100644
--- a/src/llamafactory/webui/utils.py
+++ b/src/llamafactory/webui/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import os
import signal
@@ -11,6 +25,7 @@ from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
+from ..model import QuantizationMethod
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .locales import ALERTS
@@ -19,16 +34,19 @@ if is_gradio_available():
import gradio as gr
-def abort_leaf_process(pid: int) -> None:
+def abort_process(pid: int) -> None:
r"""
- Aborts the leaf processes.
+ Aborts the processes recursively in a bottom-up way.
"""
- children = psutil.Process(pid).children()
- if children:
- for child in children:
- abort_leaf_process(child.pid)
- else:
+ try:
+ children = psutil.Process(pid).children()
+ if children:
+ for child in children:
+ abort_process(child.pid)
+
os.kill(pid, signal.SIGABRT)
+ except Exception:
+ pass
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
@@ -41,6 +59,20 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
return gr.Dropdown(interactive=True)
+def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
+ r"""
+ Returns the available quantization bits.
+ """
+ if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+ available_bits = ["none", "8", "4"]
+ elif quantization_method == QuantizationMethod.HQQ.value:
+ available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
+ elif quantization_method == QuantizationMethod.EETQ.value:
+ available_bits = ["none", "8"]
+
+ return gr.Dropdown(choices=available_bits)
+
+
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
     Modifies states after changing the training stage.
diff --git a/src/train.py b/src/train.py
index b20aa9d2..6703ffdb 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from llamafactory.train.tuner import run_exp
diff --git a/src/webui.py b/src/webui.py
index bbefb54e..99370af2 100644
--- a/src/webui.py
+++ b/src/webui.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from llamafactory.webui.interface import create_ui
diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py
new file mode 100644
index 00000000..1845df24
--- /dev/null
+++ b/tests/data/test_formatter.py
@@ -0,0 +1,123 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+
+
+def test_empty_formatter():
+ formatter = EmptyFormatter(slots=["\n"])
+ assert formatter.apply() == ["\n"]
+
+
+def test_string_formatter():
+ formatter = StringFormatter(slots=["", "Human: {{content}}\nAssistant:"])
+ assert formatter.apply(content="Hi") == ["", "Human: Hi\nAssistant:"]
+
+
+def test_function_formatter():
+ formatter = FunctionFormatter(slots=[], tool_format="default")
+ tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+ assert formatter.apply(content=tool_calls) == [
+ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
+ ]
+
+
+def test_multi_function_formatter():
+ formatter = FunctionFormatter(slots=[], tool_format="default")
+ tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
+ assert formatter.apply(content=tool_calls) == [
+ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
+ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
+ ]
+
+
+def test_default_tool_formatter():
+ formatter = ToolFormatter(tool_format="default")
+ tools = [
+ {
+ "name": "test_tool",
+ "description": "tool_desc",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "foo": {"type": "string", "description": "foo_desc"},
+ "bar": {"type": "number", "description": "bar_desc"},
+ },
+ "required": ["foo"],
+ },
+ }
+ ]
+ assert formatter.apply(content=json.dumps(tools)) == [
+ "You have access to the following tools:\n"
+ "> Tool Name: test_tool\n"
+ "Tool Description: tool_desc\n"
+ "Tool Args:\n"
+ " - foo (string, required): foo_desc\n"
+ " - bar (number): bar_desc\n\n"
+ "Use the following format if using a tool:\n"
+ "```\n"
+ "Action: tool name (one of [test_tool]).\n"
+ "Action Input: the input to the tool, in a JSON format representing the kwargs "
+ """(e.g. ```{"input": "hello world", "num_beams": 5}```).\n"""
+ "```\n"
+ ]
+
+
+def test_default_tool_extractor():
+ formatter = ToolFormatter(tool_format="default")
+ result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+ assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+def test_default_multi_tool_extractor():
+ formatter = ToolFormatter(tool_format="default")
+ result = (
+ """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+ """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
+ )
+ assert formatter.extract(result) == [
+ ("test_tool", """{"foo": "bar", "size": 10}"""),
+ ("another_tool", """{"foo": "job", "size": 2}"""),
+ ]
+
+
+def test_glm4_tool_formatter():
+ formatter = ToolFormatter(tool_format="glm4")
+ tools = [
+ {
+ "name": "test_tool",
+ "description": "tool_desc",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "foo": {"type": "string", "description": "foo_desc"},
+ "bar": {"type": "number", "description": "bar_desc"},
+ },
+ "required": ["foo"],
+ },
+ }
+ ]
+ assert formatter.apply(content=json.dumps(tools)) == [
+ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+ "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
+ "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4))
+ ]
+
+
+def test_glm4_tool_extractor():
+ formatter = ToolFormatter(tool_format="glm4")
+ result = """test_tool\n{"foo": "bar", "size": 10}\n"""
+ assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
diff --git a/tests/data/test_processor.py b/tests/data/test_processor.py
new file mode 100644
index 00000000..fa8f7172
--- /dev/null
+++ b/tests/data/test_processor.py
@@ -0,0 +1,32 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import pytest
+
+from llamafactory.data.processors.processor_utils import infer_seqlen
+
+
+@pytest.mark.parametrize(
+ "test_input,test_output",
+ [
+ ((3000, 2000, 1000), (600, 400)),
+ ((2000, 3000, 1000), (400, 600)),
+ ((1000, 100, 1000), (900, 100)),
+ ((100, 1000, 1000), (100, 900)),
+ ],
+)
+def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
+ assert test_output == infer_seqlen(*test_input)
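The parametrized cases above pin down `infer_seqlen`'s truncation policy: a side that already fits comfortably is kept whole, otherwise the `cutoff_len` budget is split between source and target in proportion to their lengths. One implementation consistent with all four expected outputs (an illustrative sketch, not necessarily the code in `processor_utils.py`):

```python
from typing import Tuple


def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    if target_len * 2 < cutoff_len:  # target is short: keep it whole
        max_target_len = cutoff_len
    elif source_len * 2 < cutoff_len:  # source is short: give target the remainder
        max_target_len = cutoff_len - source_len
    else:  # truncate both sides proportionally
        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))

    new_target_len = min(max_target_len, target_len)
    new_source_len = max(cutoff_len - new_target_len, 0)
    return new_source_len, new_target_len


assert infer_seqlen_sketch(3000, 2000, 1000) == (600, 400)
assert infer_seqlen_sketch(100, 1000, 1000) == (100, 900)
```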
diff --git a/tests/data/test_supervised.py b/tests/data/test_supervised.py
index bb7f71df..9cb49615 100644
--- a/tests/data/test_supervised.py
+++ b/tests/data/test_supervised.py
@@ -1,24 +1,40 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
+import random
import pytest
from datasets import load_dataset
+from transformers import AutoTokenizer
from llamafactory.data import get_dataset
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
- "cutoff_len": 1024,
+ "cutoff_len": 8192,
"overwrite_cache": True,
"output_dir": "dummy_dir",
"overwrite_output_dir": True,
@@ -26,19 +42,26 @@ TRAINING_ARGS = {
}
-@pytest.mark.parametrize("test_num", [5])
-def test_supervised(test_num: int):
- model_args, data_args, training_args, _, _ = get_train_args(TRAINING_ARGS)
+@pytest.mark.parametrize("num_samples", [16])
+def test_supervised(num_samples: int):
+ model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
- original_data = load_dataset(TRAINING_ARGS["dataset"], split="train")
- for test_idx in range(test_num):
- decode_result = tokenizer.decode(tokenized_data["input_ids"][test_idx])
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+
+ original_data = load_dataset(TRAIN_ARGS["dataset"], split="train")
+ indexes = random.choices(range(len(original_data)), k=num_samples)
+ for index in indexes:
+ prompt = original_data[index]["instruction"]
+ if original_data[index]["input"]:
+ prompt += "\n" + original_data[index]["input"]
+
messages = [
- {"role": "user", "content": original_data[test_idx]["instruction"]},
- {"role": "assistant", "content": original_data[test_idx]["output"]},
+ {"role": "user", "content": prompt},
+ {"role": "assistant", "content": original_data[index]["output"]},
]
- templated_result = tokenizer.apply_chat_template(messages, tokenize=False)
- assert decode_result == templated_result
+ templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False)
+ decoded_result = tokenizer.decode(tokenized_data["input_ids"][index])
+ assert templated_result == decoded_result
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
new file mode 100644
index 00000000..e4728a84
--- /dev/null
+++ b/tests/data/test_template.py
@@ -0,0 +1,80 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from transformers import AutoTokenizer
+
+from llamafactory.data import get_template_and_fix_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+MESSAGES = [
+ {"role": "user", "content": "How are you"},
+ {"role": "assistant", "content": "I am fine!"},
+ {"role": "user", "content": "你好"},
+ {"role": "assistant", "content": "很高兴认识你!"},
+]
+
+
+def test_encode_oneturn():
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+ prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+ assert tokenizer.decode(prompt_ids) == (
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
+
+
+def test_encode_multiturn():
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+ encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
+ assert tokenizer.decode(encoded_pairs[0][0]) == (
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
+ assert tokenizer.decode(encoded_pairs[1][0]) == (
+ "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+ assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
+
+
+def test_jinja_template():
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ get_template_and_fix_tokenizer(tokenizer, name="llama3")
+ assert tokenizer.chat_template != ref_tokenizer.chat_template
+ assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
+
+
+def test_qwen_template():
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
+ template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
+ prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+ assert tokenizer.decode(prompt_ids) == (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>user\nHow are you<|im_end|>\n"
+ "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+ "<|im_start|>user\n你好<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+ assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"
diff --git a/tests/eval/test_eval_template.py b/tests/eval/test_eval_template.py
new file mode 100644
index 00000000..f85d9d57
--- /dev/null
+++ b/tests/eval/test_eval_template.py
@@ -0,0 +1,91 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from llamafactory.eval.template import get_eval_template
+
+
+def test_eval_template_en():
+ support_set = [
+ {
+ "question": "Fewshot question",
+ "A": "Fewshot1",
+ "B": "Fewshot2",
+ "C": "Fewshot3",
+ "D": "Fewshot4",
+ "answer": "B",
+ }
+ ]
+ example = {
+ "question": "Target question",
+ "A": "Target1",
+ "B": "Target2",
+ "C": "Target3",
+ "D": "Target4",
+ "answer": "C",
+ }
+ template = get_eval_template(name="en")
+ messages = template.format_example(example, support_set=support_set, subject_name="SubName")
+ assert messages == [
+ {
+ "role": "user",
+ "content": (
+ "The following are multiple choice questions (with answers) about SubName.\n\n"
+ "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
+ ),
+ },
+ {"role": "assistant", "content": "B"},
+ {
+ "role": "user",
+ "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
+ },
+ {"role": "assistant", "content": "C"},
+ ]
+
+
+def test_eval_template_zh():
+ support_set = [
+ {
+ "question": "示例问题",
+ "A": "示例答案1",
+ "B": "示例答案2",
+ "C": "示例答案3",
+ "D": "示例答案4",
+ "answer": "B",
+ }
+ ]
+ example = {
+ "question": "目标问题",
+ "A": "目标答案1",
+ "B": "目标答案2",
+ "C": "目标答案3",
+ "D": "目标答案4",
+ "answer": "C",
+ }
+ template = get_eval_template(name="zh")
+ messages = template.format_example(example, support_set=support_set, subject_name="主题")
+ assert messages == [
+ {
+ "role": "user",
+ "content": (
+ "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
+ "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
+ ),
+ },
+ {"role": "assistant", "content": "B"},
+ {
+ "role": "user",
+ "content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
+ },
+ {"role": "assistant", "content": "C"},
+ ]
diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py
index 4d414289..4cae3d7c 100644
--- a/tests/model/model_utils/test_attention.py
+++ b/tests/model/model_utils/test_attention.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
@@ -6,11 +20,16 @@ from llamafactory.hparams import get_infer_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "template": "llama3",
+}
def test_attention():
- attention_available = ["off"]
+ attention_available = ["disabled"]
if is_torch_sdpa_available():
attention_available.append("sdpa")
@@ -18,18 +37,12 @@ def test_attention():
attention_available.append("fa2")
llama_attention_classes = {
- "off": "LlamaAttention",
+ "disabled": "LlamaAttention",
"sdpa": "LlamaSdpaAttention",
"fa2": "LlamaFlashAttention2",
}
for requested_attention in attention_available:
- model_args, _, finetuning_args, _ = get_infer_args(
- {
- "model_name_or_path": TINY_LLAMA,
- "template": "llama2",
- "flash_attn": requested_attention,
- }
- )
+ model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS})
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
for module in model.modules():
diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py
new file mode 100644
index 00000000..9b6dfc9e
--- /dev/null
+++ b/tests/model/model_utils/test_checkpointing.py
@@ -0,0 +1,74 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_train_args
+from llamafactory.model import load_model, load_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TRAIN_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "stage": "sft",
+ "do_train": True,
+ "finetuning_type": "lora",
+ "lora_target": "all",
+ "dataset": "llamafactory/tiny-supervised-dataset",
+ "dataset_dir": "ONLINE",
+ "template": "llama3",
+ "cutoff_len": 1024,
+ "overwrite_cache": True,
+ "output_dir": "dummy_dir",
+ "overwrite_output_dir": True,
+ "fp16": True,
+}
+
+
+def test_checkpointing_enable():
+ model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": False, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+ assert getattr(module, "gradient_checkpointing") is True
+
+
+def test_checkpointing_disable():
+ model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": True, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
+ assert getattr(module, "gradient_checkpointing") is False
+
+
+def test_upcast_layernorm():
+ model_args, _, _, finetuning_args, _ = get_train_args({"upcast_layernorm": True, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ for name, param in model.named_parameters():
+ if param.ndim == 1 and "norm" in name:
+ assert param.dtype == torch.float32
+
+
+def test_upcast_lmhead_output():
+ model_args, _, _, finetuning_args, _ = get_train_args({"upcast_lmhead_output": True, **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+ inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
+ outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
+ assert outputs.dtype == torch.float32
diff --git a/tests/model/test_base.py b/tests/model/test_base.py
new file mode 100644
index 00000000..6431a504
--- /dev/null
+++ b/tests/model/test_base.py
@@ -0,0 +1,80 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Dict
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM
+from trl import AutoModelForCausalLMWithValueHead
+
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_infer_args
+from llamafactory.model import load_model, load_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
+
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
+
+
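+# two models match when every tensor in their state dicts is numerically close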
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
+ state_dict_a = model_a.state_dict()
+ state_dict_b = model_b.state_dict()
+ assert set(state_dict_a.keys()) == set(state_dict_b.keys())
+ for name in state_dict_a.keys():
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
+
+
+@pytest.fixture
+def fix_valuehead_cpu_loading():
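+ # patch post_init to load only the v_head weights from the state dict (works around CPU-only loading)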
+ def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+ state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+ self.v_head.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
+def test_base():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ compare_model(model, ref_model)
+
+
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
+def test_valuehead():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(
+ tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
+ )
+
+ ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
+ TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ ref_model.v_head = ref_model.v_head.to(torch.float16)
+ compare_model(model, ref_model)
diff --git a/tests/model/test_freeze.py b/tests/model/test_freeze.py
index c6cdec78..5f478af6 100644
--- a/tests/model/test_freeze.py
+++ b/tests/model/test_freeze.py
@@ -1,19 +1,33 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import torch
-from llamafactory.hparams import get_train_args
+from llamafactory.hparams import get_infer_args, get_train_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "freeze",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
@@ -23,16 +37,19 @@ TRAINING_ARGS = {
"fp16": True,
}
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "finetuning_type": "freeze",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
-def test_freeze_all_modules():
- model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "freeze_trainable_layers": 1,
- **TRAINING_ARGS,
- }
- )
+
+def test_freeze_train_all_modules():
+ model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS})
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
for name, param in model.named_parameters():
if name.startswith("model.layers.1."):
assert param.requires_grad is True
@@ -42,16 +59,13 @@ def test_freeze_all_modules():
assert param.dtype == torch.float16
-def test_freeze_extra_modules():
+def test_freeze_train_extra_modules():
model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "freeze_trainable_layers": 1,
- "freeze_extra_modules": "embed_tokens,lm_head",
- **TRAINING_ARGS,
- }
+ {"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS}
)
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
for name, param in model.named_parameters():
if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]):
assert param.requires_grad is True
@@ -59,3 +73,13 @@ def test_freeze_extra_modules():
else:
assert param.requires_grad is False
assert param.dtype == torch.float16
+
+
+def test_freeze_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ for param in model.parameters():
+ assert param.requires_grad is False
+ assert param.dtype == torch.float16
diff --git a/tests/model/test_full.py b/tests/model/test_full.py
index ef57a980..0a6e0743 100644
--- a/tests/model/test_full.py
+++ b/tests/model/test_full.py
@@ -1,19 +1,33 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import torch
-from llamafactory.hparams import get_train_args
+from llamafactory.hparams import get_infer_args, get_train_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
@@ -23,11 +37,29 @@ TRAINING_ARGS = {
"fp16": True,
}
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "finetuning_type": "full",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
-def test_full():
- model_args, _, _, finetuning_args, _ = get_train_args(TRAINING_ARGS)
+
+def test_full_train():
+ model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
for param in model.parameters():
assert param.requires_grad is True
assert param.dtype == torch.float32
+
+
+def test_full_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ for param in model.parameters():
+ assert param.requires_grad is False
+ assert param.dtype == torch.float16
diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py
index 1f2c02ae..630e5f75 100644
--- a/tests/model/test_lora.py
+++ b/tests/model/test_lora.py
@@ -1,19 +1,43 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
+from typing import Dict, Sequence
+import pytest
import torch
+from peft import LoraModel, PeftModel
+from transformers import AutoModelForCausalLM
+from trl import AutoModelForCausalLMWithValueHead
-from llamafactory.hparams import get_train_args
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_infer_args, get_train_args
from llamafactory.model import load_model, load_tokenizer
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
-TRAINING_ARGS = {
+TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
+
+TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
+
+TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
- "dataset": "llamafactory/tiny_dataset",
+ "dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
@@ -23,16 +47,70 @@ TRAINING_ARGS = {
"fp16": True,
}
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "adapter_name_or_path": TINY_LLAMA_ADAPTER,
+ "finetuning_type": "lora",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
-def test_lora_all_modules():
- model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "lora_target": "all",
- **TRAINING_ARGS,
- }
+
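+# reference model: base weights plus the published LoRA adapter, with trainable params upcast to float32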
+def load_reference_model(is_trainable: bool = False) -> "LoraModel":
+ model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
)
+ lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable)
+ for param in filter(lambda p: p.requires_grad, lora_model.parameters()):
+ param.data = param.data.to(torch.float32)
+
+ return lora_model
+
+
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
+ state_dict_a = model_a.state_dict()
+ state_dict_b = model_b.state_dict()
+ assert set(state_dict_a.keys()) == set(state_dict_b.keys())
+ for name in state_dict_a.keys():
+ if any(key in name for key in diff_keys):
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
+ else:
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
+
+
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+ def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+ state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+ self.v_head.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
+def test_lora_train_qv_modules():
+ model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
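+ # collect the module names that received LoRA weights; only q_proj and v_proj should be adapted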
+ linear_modules = set()
+ for name, param in model.named_parameters():
+ if any(module in name for module in ["lora_A", "lora_B"]):
+ linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
+ assert param.requires_grad is True
+ assert param.dtype == torch.float32
+ else:
+ assert param.requires_grad is False
+ assert param.dtype == torch.float16
+
+ assert linear_modules == {"q_proj", "v_proj"}
+
+
+def test_lora_train_all_modules():
+ model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS})
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
linear_modules = set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
@@ -46,16 +124,13 @@ def test_lora_all_modules():
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
-def test_lora_extra_modules():
+def test_lora_train_extra_modules():
model_args, _, _, finetuning_args, _ = get_train_args(
- {
- "lora_target": "all",
- "additional_target": "embed_tokens,lm_head",
- **TRAINING_ARGS,
- }
+ {"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS}
)
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
extra_modules = set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
@@ -70,3 +145,54 @@ def test_lora_extra_modules():
assert param.dtype == torch.float16
assert extra_modules == {"embed_tokens", "lm_head"}
+
+
+def test_lora_train_old_adapters():
+ model_args, _, _, finetuning_args, _ = get_train_args(
+ {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": False, **TRAIN_ARGS}
+ )
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
+ ref_model = load_reference_model(is_trainable=True)
+ compare_model(model, ref_model)
+
+
+def test_lora_train_new_adapters():
+ model_args, _, _, finetuning_args, _ = get_train_args(
+ {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": True, **TRAIN_ARGS}
+ )
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
+ ref_model = load_reference_model(is_trainable=True)
+ compare_model(
+ model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
+ )
+
+
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
+def test_lora_train_valuehead():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(
+ tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True, add_valuehead=True
+ )
+
+ ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
+ TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ state_dict = model.state_dict()
+ ref_state_dict = ref_model.state_dict()
+
+ assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"])
+ assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])
+
+
+def test_lora_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ ref_model = load_reference_model().merge_and_unload()
+ compare_model(model, ref_model)
diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py
new file mode 100644
index 00000000..030310d0
--- /dev/null
+++ b/tests/model/test_pissa.py
@@ -0,0 +1,90 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+from peft import LoraModel, PeftModel
+from transformers import AutoModelForCausalLM
+
+from llamafactory.extras.misc import get_current_device
+from llamafactory.hparams import get_infer_args, get_train_args
+from llamafactory.model import load_model, load_tokenizer
+
+
+TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TINY_LLAMA_PISSA = os.environ.get("TINY_LLAMA_PISSA", "llamafactory/tiny-random-Llama-3-pissa")
+
+TRAIN_ARGS = {
+ "model_name_or_path": TINY_LLAMA,
+ "stage": "sft",
+ "do_train": True,
+ "finetuning_type": "lora",
+ "pissa_init": True,
+ "pissa_iter": -1,
+ "dataset": "llamafactory/tiny-supervised-dataset",
+ "dataset_dir": "ONLINE",
+ "template": "llama3",
+ "cutoff_len": 1024,
+ "overwrite_cache": True,
+ "output_dir": "dummy_dir",
+ "overwrite_output_dir": True,
+ "fp16": True,
+}
+
+INFER_ARGS = {
+ "model_name_or_path": TINY_LLAMA_PISSA,
+ "adapter_name_or_path": TINY_LLAMA_PISSA,
+ "adapter_folder": "pissa_init",
+ "finetuning_type": "lora",
+ "template": "llama3",
+ "infer_dtype": "float16",
+}
+
+
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
+ state_dict_a = model_a.state_dict()
+ state_dict_b = model_b.state_dict()
+ assert set(state_dict_a.keys()) == set(state_dict_b.keys())
+ for name in state_dict_a.keys():
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
+
+
+def test_pissa_init():
+ model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+
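+ # PiSSA initialization should reproduce the reference adapter saved under the pissa_init subfolder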
+ base_model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init", is_trainable=True)
+ for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
+ param.data = param.data.to(torch.float32)
+
+ compare_model(model, ref_model)
+
+
+def test_pissa_inference():
+ model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
+ tokenizer_module = load_tokenizer(model_args)
+ model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
+
+ base_model = AutoModelForCausalLM.from_pretrained(
+ TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
+ )
+ ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init")
+ ref_model = ref_model.merge_and_unload()
+ compare_model(model, ref_model)