diff --git a/.dockerignore b/.dockerignore index 2ac0e11d..23ad75a8 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,10 +4,10 @@ .venv cache data +docker +saves hf_cache output -examples .dockerignore .gitattributes .gitignore -Dockerfile diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 1d962200..768adea6 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -38,7 +38,9 @@ body: 请合理使用 Markdown 标签来格式化您的文本。 placeholder: | + ```bash llamafactory-cli train ... + ``` - type: textarea id: expected-behavior diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b31e9d19..d23d6be3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,3 +5,4 @@ Fixes # (issue) ## Before submitting - [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)? +- [ ] Did you write any new necessary tests? diff --git a/.github/workflows/label_issue.yml b/.github/workflows/label_issue.yml new file mode 100644 index 00000000..ffd644a7 --- /dev/null +++ b/.github/workflows/label_issue.yml @@ -0,0 +1,27 @@ +name: label_issue + +on: + issues: + types: + - opened + +jobs: + label_issue: + runs-on: ubuntu-latest + + steps: + - env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + ISSUE_URL: ${{ github.event.issue.html_url }} + ISSUE_TITLE: ${{ github.event.issue.title }} + run: | + LABEL=pending + NPU_KEYWORDS=(npu ascend huawei 华为 昇腾) + ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]') + for KEYWORD in ${NPU_KEYWORDS[@]}; do + if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then + LABEL=pending,npu + break + fi + done + gh issue edit $ISSUE_URL --add-label $LABEL diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..15c7153e --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,40 @@ +name: publish + +on: + release: + types: + - published + +jobs: + publish: + name: Upload release to PyPI + + runs-on: ubuntu-latest + + environment: + name: release + url: https://pypi.org/p/llamafactory + + permissions: + id-token: write + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.8" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install build + + - name: Build package + run: | + python -m build + + - name: Publish package + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 32edf6a8..73d77de5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,21 +19,27 @@ on: jobs: tests: runs-on: ubuntu-latest + steps: - - uses: actions/checkout@v4 + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.8" cache: "pip" cache-dependency-path: "setup.py" + - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install .[torch,dev] + python -m pip install ".[torch,dev]" + - name: Check quality run: | make style && make quality + - name: Test with pytest run: | make test diff --git a/.gitignore b/.gitignore index 0355c666..82e6e9e6 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,8 @@ cython_debug/ .idea/ # custom .gitignore -user.config -saves/ cache/ +config/ +saves/ +output/ +wandb/ diff --git a/CITATION.cff b/CITATION.cff index 4caf3787..01b4c9fd 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -12,12 +12,16 @@ authors: given-names: "Yanhan" - family-names: "Luo" given-names: "Zheyan" +- family-names: "Feng" + given-names: "Zhangchi" - family-names: "Ma" given-names: "Yongqiang" title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models" url: "https://arxiv.org/abs/2403.13372" preferred-citation: - type: article + type: conference-paper + conference: + name: "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)" authors: - family-names: "Zheng" given-names: "Yaowei" @@ -29,9 +33,12 @@ preferred-citation: given-names: "Yanhan" - family-names: "Luo" given-names: "Zheyan" + - family-names: "Feng" + given-names: "Zhangchi" - family-names: "Ma" given-names: "Yongqiang" - journal: "arXiv preprint arXiv:2403.13372" title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models" url: "https://arxiv.org/abs/2403.13372" year: 2024 + publisher: "Association for Computational Linguistics" + address: "Bangkok, Thailand" diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 0a35e355..00000000 --- a/Dockerfile +++ /dev/null @@ -1,14 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:24.01-py3 - -WORKDIR /app - -COPY requirements.txt /app/ -RUN pip install -r requirements.txt - -COPY . /app/ -RUN pip install -e .[metrics,bitsandbytes,qwen] - -VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ] -EXPOSE 7860 - -CMD [ "llamafactory-cli", "webui" ] diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..82c51f63 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include LICENSE requirements.txt diff --git a/Makefile b/Makefile index 65be047b..3f13b215 100644 --- a/Makefile +++ b/Makefile @@ -11,4 +11,4 @@ style: ruff format $(check_dirs) test: - pytest tests/ + CUDA_VISIBLE_DEVICES= pytest tests/ diff --git a/README.md b/README.md index fb6c5782..3d3feae5 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-44-green)](#projects-using-llama-factory) +[![Citation](https://img.shields.io/badge/citation-71-green)](#projects-using-llama-factory) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -15,7 +15,7 @@ [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) -👋 Join our [WeChat](assets/wechat.jpg). +👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg). \[ English | [中文](README_zh.md) \] @@ -48,8 +48,8 @@ Choose your path: - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. -- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. -- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning. +- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. +- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. @@ -71,9 +71,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog -[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models. +[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. -[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models. +[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage. @@ -151,35 +151,32 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Supported Models -| Model | Model size | Template | -| -------------------------------------------------------- | -------------------------------- | --------- | -| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | -| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | -| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | -| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | -| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | -| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | -| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | -| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma | -| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 | -| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | -| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | -| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | -| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | -| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | -| [PaliGemma](https://huggingface.co/google) | 3B | gemma | -| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | -| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen | -| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen | -| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen | -| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - | -| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | -| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi | -| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | -| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | +| Model | Model size | Template | +| ------------------------------------------------------------ | -------------------------------- | --------- | +| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | +| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | +| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | +| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| [PaliGemma](https://huggingface.co/google) | 3B | gemma | +| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | +| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | > [!NOTE] > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models. @@ -259,6 +256,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction) - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo) +- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2) +- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub) +- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) @@ -335,10 +335,10 @@ huggingface-cli login ```bash git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git cd LLaMA-Factory -pip install -e '.[torch,metrics]' +pip install -e ".[torch,metrics]" ``` -Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality +Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality > [!TIP] > Use `pip install --no-deps -e .` to resolve package conflicts. @@ -357,9 +357,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
For Ascend NPU users -Join [NPU user group](assets/wechat_npu.jpg). - -To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands: +To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands: ```bash # replace the url according to your CANN version and devices @@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh | torch-npu | 2.1.0 | 2.1.0.post3 | | deepspeed | 0.13.2 | 0.13.2 | -Docker image: - -- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) -- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) - Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use. If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations. +Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) +
### Data Preparation @@ -405,9 +400,9 @@ Please refer to [data/README.md](data/README.md) for checking the details about Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively. ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml ``` See [examples/README.md](examples/README.md) for advanced usage (including distributed training). @@ -417,34 +412,89 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio)) -#### Use local environment - ```bash -CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui +llamafactory-cli webui ``` - +### Build Docker -#### Use Docker +For CUDA users: ```bash -docker build -f ./Dockerfile -t llama-factory:latest . -docker run --gpus=all \ - -v ./hf_cache:/root/.cache/huggingface/ \ +cd docker/docker-cuda/ +docker-compose up -d +docker-compose exec llamafactory bash +``` + +For Ascend NPU users: + +```bash +cd docker/docker-npu/ +docker-compose up -d +docker-compose exec llamafactory bash +``` + +
Build without Docker Compose + +For CUDA users: + +```bash +docker build -f ./docker/docker-cuda/Dockerfile \ + --build-arg INSTALL_BNB=false \ + --build-arg INSTALL_VLLM=false \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg INSTALL_FLASHATTN=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +docker run -dit --gpus=all \ + -v ./hf_cache:/root/.cache/huggingface \ + -v ./ms_cache:/root/.cache/modelscope \ -v ./data:/app/data \ -v ./output:/app/output \ -p 7860:7860 \ + -p 8000:8000 \ --shm-size 16G \ - --name llama_factory \ - -d llama-factory:latest + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash ``` -#### Use Docker Compose +For Ascend NPU users: ```bash -docker compose -f ./docker-compose.yml up -d +# Choose docker image upon your environment +docker build -f ./docker/docker-npu/Dockerfile \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +# Change `device` upon your resources +docker run -dit \ + -v ./hf_cache:/root/.cache/huggingface \ + -v ./ms_cache:/root/.cache/modelscope \ + -v ./data:/app/data \ + -v ./output:/app/output \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -p 7860:7860 \ + -p 8000:8000 \ + --device /dev/davinci0 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + --shm-size 16G \ + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash ``` +
+
Details about volume - hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory. @@ -456,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d ### Deploy with OpenAI-style API and vLLM ```bash -CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml +API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml ``` > [!TIP] @@ -474,7 +524,7 @@ Train the model by specifying a model ID of the ModelScope Hub as the `model_nam ### Use W&B Logger -To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments. +To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files. ```yaml report_to: wandb @@ -494,38 +544,63 @@ If you have a project that should be incorporated, please contact via email or c 1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526) 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816) 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710) -1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319) -1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286) +1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319) +1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286) 1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904) 1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625) 1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176) 1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187) 1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746) 1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801) -1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809) +1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809) 1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819) 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204) 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714) -1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043) +1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043) 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333) 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419) 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228) 1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073) 1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541) 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246) -1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008) +1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008) 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443) 1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604) 1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827) 1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167) -1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316) +1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316) 1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084) 1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836) 1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581) 1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215) 1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621) -1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140) -1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) +1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140) +1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) +1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760) +1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378) +1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055) +1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739) +1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816) +1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215) +1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30) +1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380) +1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106) +1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136) +1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496) +1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688) +1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955) +1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973) +1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115) +1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815) +1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099) +1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173) +1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074) +1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408) +1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546) +1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695) +1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233) +1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069) +1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25) 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B. 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge. 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B. @@ -533,6 +608,8 @@ If you have a project that should be incorporated, please contact via email or c 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods. 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt) 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B. +1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models. +1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
@@ -540,17 +617,19 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). -Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation If this work is helpful, please kindly cite as: ```bibtex -@article{zheng2024llamafactory, +@inproceedings{zheng2024llamafactory, title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models}, - author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma}, - journal={arXiv preprint arXiv:2403.13372}, + author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma}, + booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)}, + address={Bangkok, Thailand}, + publisher={Association for Computational Linguistics}, year={2024}, url={http://arxiv.org/abs/2403.13372} } diff --git a/README_zh.md b/README_zh.md index 142254df..cb5a42e4 100644 --- a/README_zh.md +++ b/README_zh.md @@ -4,7 +4,7 @@ [![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE) [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-44-green)](#使用了-llama-factory-的项目) +[![Citation](https://img.shields.io/badge/citation-71-green)](#使用了-llama-factory-的项目) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -15,7 +15,7 @@ [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) -👋 加入我们的[微信群](assets/wechat.jpg)。 +👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。 \[ [English](README.md) | 中文 \] @@ -48,8 +48,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 -- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。 -- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。 +- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 +- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 @@ -71,9 +71,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 -[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。 +[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 -[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。 +[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 @@ -151,35 +151,32 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 模型 -| 模型名 | 模型大小 | Template | -| -------------------------------------------------------- | -------------------------------- | --------- | -| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | -| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | -| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | -| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | -| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | -| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | -| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | -| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma | -| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 | -| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | -| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | -| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | -| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | -| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | -| [PaliGemma](https://huggingface.co/google) | 3B | gemma | -| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | -| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen | -| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen | -| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen | -| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - | -| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | -| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi | -| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | -| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | +| 模型名 | 模型大小 | Template | +| ------------------------------------------------------------ | -------------------------------- | --------- | +| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | +| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | +| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | +| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| [PaliGemma](https://huggingface.co/google) | 3B | gemma | +| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | +| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | > [!NOTE] > 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。 @@ -259,6 +256,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction) - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo) +- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2) +- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub) +- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) @@ -335,10 +335,10 @@ huggingface-cli login ```bash git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git cd LLaMA-Factory -pip install -e '.[torch,metrics]' +pip install -e ".[torch,metrics]" ``` -可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality +可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality > [!TIP] > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 @@ -357,9 +357,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
昇腾 NPU 用户指南 -加入 [NPU 用户群](assets/wechat_npu.jpg)。 - -在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令: +在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令: ```bash # 请替换 URL 为 CANN 版本和设备型号对应的 URL @@ -382,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh | torch-npu | 2.1.0 | 2.1.0.post3 | | deepspeed | 0.13.2 | 0.13.2 | -Docker 镜像: - -- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) -- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) - 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。 如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。 +下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) +
### 数据准备 @@ -405,9 +400,9 @@ Docker 镜像: 下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml ``` 高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。 @@ -417,32 +412,89 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s ### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动) -#### 使用本地环境 - ```bash -CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui +llamafactory-cli webui ``` -#### 使用 Docker +### 构建 Docker + +CUDA 用户: ```bash -docker build -f ./Dockerfile -t llama-factory:latest . -docker run --gpus=all \ - -v ./hf_cache:/root/.cache/huggingface/ \ +cd docker/docker-cuda/ +docker-compose up -d +docker-compose exec llamafactory bash +``` + +昇腾 NPU 用户: + +```bash +cd docker/docker-npu/ +docker-compose up -d +docker-compose exec llamafactory bash +``` + +
不使用 Docker Compose 构建 + +CUDA 用户: + +```bash +docker build -f ./docker/docker-cuda/Dockerfile \ + --build-arg INSTALL_BNB=false \ + --build-arg INSTALL_VLLM=false \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg INSTALL_FLASHATTN=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +docker run -dit --gpus=all \ + -v ./hf_cache:/root/.cache/huggingface \ + -v ./ms_cache:/root/.cache/modelscope \ -v ./data:/app/data \ -v ./output:/app/output \ -p 7860:7860 \ + -p 8000:8000 \ --shm-size 16G \ - --name llama_factory \ - -d llama-factory:latest + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash ``` -#### 使用 Docker Compose +昇腾 NPU 用户: ```bash -docker compose -f ./docker-compose.yml up -d +# 根据您的环境选择镜像 +docker build -f ./docker/docker-npu/Dockerfile \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +# 根据您的资源更改 `device` +docker run -dit \ + -v ./hf_cache:/root/.cache/huggingface \ + -v ./ms_cache:/root/.cache/modelscope \ + -v ./data:/app/data \ + -v ./output:/app/output \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + -p 7860:7860 \ + -p 8000:8000 \ + --device /dev/davinci0 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + --shm-size 16G \ + --name llamafactory \ + llamafactory:latest + +docker exec -it llamafactory bash ``` +
+
数据卷详情 - hf_cache:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。 @@ -454,7 +506,7 @@ docker compose -f ./docker-compose.yml up -d ### 利用 vLLM 部署 OpenAI API ```bash -CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml +API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml ``` > [!TIP] @@ -472,7 +524,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1` ### 使用 W&B 面板 -若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。 +若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。 ```yaml report_to: wandb @@ -492,38 +544,63 @@ run_name: test_run # 可选 1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526) 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816) 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710) -1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319) -1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286) +1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319) +1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286) 1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904) 1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625) 1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176) 1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187) 1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746) 1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801) -1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809) +1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809) 1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819) 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204) 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714) -1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043) +1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043) 1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333) 1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419) 1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228) 1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073) 1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541) 1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246) -1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008) +1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008) 1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443) 1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604) 1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827) 1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167) -1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316) +1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316) 1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084) 1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836) 1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581) 1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215) 1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621) -1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140) -1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) +1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140) +1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585) +1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760) +1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378) +1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055) +1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739) +1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816) +1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215) +1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30) +1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380) +1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106) +1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136) +1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496) +1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688) +1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955) +1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973) +1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115) +1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815) +1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099) +1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173) +1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074) +1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408) +1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546) +1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695) +1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233) +1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069) +1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25) 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。 @@ -531,6 +608,8 @@ run_name: test_run # 可选 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。 1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt) 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。 +1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。 +1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: 在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
@@ -538,17 +617,19 @@ run_name: test_run # 可选 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 如果您觉得此项目有帮助,请考虑以下列格式引用 ```bibtex -@article{zheng2024llamafactory, - title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models}, - author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma}, - journal={arXiv preprint arXiv:2403.13372}, +@inproceedings{zheng2024llamafactory, + title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models}, + author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma}, + booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)}, + address={Bangkok, Thailand}, + publisher={Association for Computational Linguistics}, year={2024}, url={http://arxiv.org/abs/2403.13372} } diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 9602a3e3..00000000 --- a/docker-compose.yml +++ /dev/null @@ -1,23 +0,0 @@ -version: '3.8' - -services: - llama-factory: - build: - dockerfile: Dockerfile - context: . - container_name: llama_factory - volumes: - - ./hf_cache:/root/.cache/huggingface/ - - ./data:/app/data - - ./output:/app/output - ports: - - "7860:7860" - ipc: host - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: "all" - capabilities: [gpu] - restart: unless-stopped diff --git a/docker/docker-cuda/Dockerfile b/docker/docker-cuda/Dockerfile new file mode 100644 index 00000000..d94aa970 --- /dev/null +++ b/docker/docker-cuda/Dockerfile @@ -0,0 +1,58 @@ +# Use the NVIDIA official image with PyTorch 2.3.0 +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html +FROM nvcr.io/nvidia/pytorch:24.02-py3 + +# Define environments +ENV MAX_JOBS=4 +ENV FLASH_ATTENTION_FORCE_BUILD=TRUE + +# Define installation arguments +ARG INSTALL_BNB=false +ARG INSTALL_VLLM=false +ARG INSTALL_DEEPSPEED=false +ARG INSTALL_FLASHATTN=false +ARG PIP_INDEX=https://pypi.org/simple + +# Set the working directory +WORKDIR /app + +# Install the requirements +COPY requirements.txt /app +RUN pip config set global.index-url "$PIP_INDEX" && \ + pip config set global.extra-index-url "$PIP_INDEX" && \ + python -m pip install --upgrade pip && \ + python -m pip install -r requirements.txt + +# Rebuild flash attention +RUN pip uninstall -y transformer-engine flash-attn && \ + if [ "$INSTALL_FLASHATTN" == "true" ]; then \ + pip uninstall -y ninja && pip install ninja && \ + pip install --no-cache-dir flash-attn --no-build-isolation; \ + fi + +# Copy the rest of the application into the image +COPY . /app + +# Install the LLaMA Factory +RUN EXTRA_PACKAGES="metrics"; \ + if [ "$INSTALL_BNB" == "true" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \ + fi; \ + if [ "$INSTALL_VLLM" == "true" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \ + fi; \ + if [ "$INSTALL_DEEPSPEED" == "true" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \ + fi; \ + pip install -e ".[$EXTRA_PACKAGES]" + +# Set up volumes +VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ] + +# Expose port 7860 for the LLaMA Board +ENV GRADIO_SERVER_PORT 7860 +EXPOSE 7860 + +# Expose port 8000 for the API service +ENV API_PORT 8000 +EXPOSE 8000 diff --git a/docker/docker-cuda/docker-compose.yml b/docker/docker-cuda/docker-compose.yml new file mode 100644 index 00000000..16267dc3 --- /dev/null +++ b/docker/docker-cuda/docker-compose.yml @@ -0,0 +1,32 @@ +services: + llamafactory: + build: + dockerfile: ./docker/docker-cuda/Dockerfile + context: ../.. + args: + INSTALL_BNB: false + INSTALL_VLLM: false + INSTALL_DEEPSPEED: false + INSTALL_FLASHATTN: false + PIP_INDEX: https://pypi.org/simple + container_name: llamafactory + volumes: + - ../../hf_cache:/root/.cache/huggingface + - ../../ms_cache:/root/.cache/modelscope + - ../../data:/app/data + - ../../output:/app/output + ports: + - "7860:7860" + - "8000:8000" + ipc: host + tty: true + stdin_open: true + command: bash + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: "all" + capabilities: [gpu] + restart: unless-stopped diff --git a/docker/docker-npu/Dockerfile b/docker/docker-npu/Dockerfile new file mode 100644 index 00000000..34cf9616 --- /dev/null +++ b/docker/docker-npu/Dockerfile @@ -0,0 +1,45 @@ +# Use the Ubuntu 22.04 image with CANN 8.0.rc1 +# More versions can be found at https://hub.docker.com/r/cosdt/cann/tags +# FROM cosdt/cann:8.0.rc1-910-ubuntu22.04 +FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04 +# FROM cosdt/cann:8.0.rc1-910-openeuler22.03 +# FROM cosdt/cann:8.0.rc1-910b-openeuler22.03 + +# Define environments +ENV DEBIAN_FRONTEND=noninteractive + +# Define installation arguments +ARG INSTALL_DEEPSPEED=false +ARG PIP_INDEX=https://pypi.org/simple +ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu + +# Set the working directory +WORKDIR /app + +# Install the requirements +COPY requirements.txt /app +RUN pip config set global.index-url "$PIP_INDEX" && \ + pip config set global.extra-index-url "$TORCH_INDEX" && \ + python -m pip install --upgrade pip && \ + python -m pip install -r requirements.txt + +# Copy the rest of the application into the image +COPY . /app + +# Install the LLaMA Factory +RUN EXTRA_PACKAGES="torch-npu,metrics"; \ + if [ "$INSTALL_DEEPSPEED" == "true" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \ + fi; \ + pip install -e ".[$EXTRA_PACKAGES]" + +# Set up volumes +VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ] + +# Expose port 7860 for the LLaMA Board +ENV GRADIO_SERVER_PORT 7860 +EXPOSE 7860 + +# Expose port 8000 for the API service +ENV API_PORT 8000 +EXPOSE 8000 diff --git a/docker/docker-npu/docker-compose.yml b/docker/docker-npu/docker-compose.yml new file mode 100644 index 00000000..657cba9f --- /dev/null +++ b/docker/docker-npu/docker-compose.yml @@ -0,0 +1,31 @@ +services: + llamafactory: + build: + dockerfile: ./docker/docker-npu/Dockerfile + context: ../.. + args: + INSTALL_DEEPSPEED: false + PIP_INDEX: https://pypi.org/simple + container_name: llamafactory + volumes: + - ../../hf_cache:/root/.cache/huggingface + - ../../ms_cache:/root/.cache/modelscope + - ../../data:/app/data + - ../../output:/app/output + - /usr/local/dcmi:/usr/local/dcmi + - /usr/local/bin/npu-smi:/usr/local/bin/npu-smi + - /usr/local/Ascend/driver:/usr/local/Ascend/driver + - /etc/ascend_install.info:/etc/ascend_install.info + ports: + - "7860:7860" + - "8000:8000" + ipc: host + tty: true + stdin_open: true + command: bash + devices: + - /dev/davinci0 + - /dev/davinci_manager + - /dev/devmm_svm + - /dev/hisi_hdc + restart: unless-stopped diff --git a/evaluation/ceval/ceval.py b/evaluation/ceval/ceval.py index 4111d6b4..48442d50 100644 --- a/evaluation/ceval/ceval.py +++ b/evaluation/ceval/ceval.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import datasets diff --git a/evaluation/cmmlu/cmmlu.py b/evaluation/cmmlu/cmmlu.py index 37efb328..5ff548a4 100644 --- a/evaluation/cmmlu/cmmlu.py +++ b/evaluation/cmmlu/cmmlu.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import datasets diff --git a/evaluation/mmlu/mmlu.py b/evaluation/mmlu/mmlu.py index a4530250..1065fb31 100644 --- a/evaluation/mmlu/mmlu.py +++ b/evaluation/mmlu/mmlu.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import datasets diff --git a/examples/README.md b/examples/README.md index f985d552..d5aca5ad 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory. ## Table of Contents -- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu) -- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu) -- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus) -- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus) -- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus) +- [LoRA Fine-Tuning](#lora-fine-tuning) +- [QLoRA Fine-Tuning](#qlora-fine-tuning) +- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning) - [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization) - [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models) - [Extras](#extras) +Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices. + ## Examples -### LoRA Fine-Tuning on A Single GPU +### LoRA Fine-Tuning #### (Continuous) Pre-Training ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml +llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml ``` #### Supervised Fine-Tuning ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ``` #### Multimodal Supervised Fine-Tuning ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml +llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml ``` #### Reward Modeling ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml +llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml ``` #### PPO Training ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml +llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml ``` #### DPO/ORPO/SimPO Training ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml +llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml ``` #### KTO Training ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml +llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml ``` #### Preprocess Dataset @@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset. ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml +llamafactory-cli train examples/train_lora/llama3_preprocess.yaml ``` #### Evaluating on MMLU/CMMLU/C-Eval Benchmarks ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml +llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml ``` #### Batch Predicting and Computing BLEU and ROUGE Scores ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml -``` - -### QLoRA Fine-Tuning on a Single GPU - -#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended) - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml -``` - -#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml -``` - -#### Supervised Fine-Tuning with 4-bit AWQ Quantization - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml -``` - -#### Supervised Fine-Tuning with 2-bit AQLM Quantization - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml -``` - -### LoRA Fine-Tuning on Multiple GPUs - -#### Supervised Fine-Tuning on Single Node - -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml ``` #### Supervised Fine-Tuning on Multiple Nodes ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ``` #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding) ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml +FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` -### LoRA Fine-Tuning on Multiple NPUs +### QLoRA Fine-Tuning -#### Supervised Fine-Tuning with DeepSpeed ZeRO-0 +#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended) ```bash -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml +llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml ``` -### Full-Parameter Fine-Tuning on Multiple GPUs +#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml +``` + +#### Supervised Fine-Tuning with 4-bit AWQ Quantization + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml +``` + +#### Supervised Fine-Tuning with 2-bit AQLM Quantization + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml +``` + +### Full-Parameter Fine-Tuning #### Supervised Fine-Tuning on Single Node ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml +FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml ``` #### Supervised Fine-Tuning on Multiple Nodes ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml ``` #### Batch Predicting and Computing BLEU and ROUGE Scores ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml +llamafactory-cli train examples/train_full/llama3_full_predict.yaml ``` ### Merging LoRA Adapters and Quantization @@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters. ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml ``` #### Quantizing Model using AutoGPTQ ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml +llamafactory-cli export examples/merge_lora/llama3_gptq.yaml ``` ### Inferring LoRA Fine-Tuned Models -Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices. - #### Use CLI ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml ``` #### Use Web UI ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml +llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml ``` #### Launch OpenAI-style API ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml +llamafactory-cli api examples/inference/llama3_lora_sft.yaml ``` ### Extras @@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y #### Full-Parameter Fine-Tuning using GaLore ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml +llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml ``` #### Full-Parameter Fine-Tuning using BAdam ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml +llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml ``` #### LoRA+ Fine-Tuning ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml +llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml +``` + +#### PiSSA Fine-Tuning + +```bash +llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml ``` #### Mixture-of-Depths Fine-Tuning ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml +llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml ``` #### LLaMA-Pro Fine-Tuning ```bash bash examples/extras/llama_pro/expand.sh -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml +llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml ``` #### FSDP+QLoRA Fine-Tuning ```bash -bash examples/extras/fsdp_qlora/single_node.sh +bash examples/extras/fsdp_qlora/train.sh ``` diff --git a/examples/README_zh.md b/examples/README_zh.md index cf5bbf49..d96bf882 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -4,59 +4,59 @@ ## 目录 -- [单 GPU LoRA 微调](#单-gpu-lora-微调) -- [单 GPU QLoRA 微调](#单-gpu-qlora-微调) -- [多 GPU LoRA 微调](#多-gpu-lora-微调) -- [多 NPU LoRA 微调](#多-npu-lora-微调) -- [多 GPU 全参数微调](#多-gpu-全参数微调) +- [LoRA 微调](#lora-微调) +- [QLoRA 微调](#qlora-微调) +- [全参数微调](#全参数微调) - [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化) - [推理 LoRA 模型](#推理-lora-模型) - [杂项](#杂项) +使用 `CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。 + ## 示例 -### 单 GPU LoRA 微调 +### LoRA 微调 #### (增量)预训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml +llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml ``` #### 指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ``` #### 多模态指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml +llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml ``` #### 奖励模型训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml +llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml ``` #### PPO 训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml +llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml ``` #### DPO/ORPO/SimPO 训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml +llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml ``` #### KTO 训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml +llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml ``` #### 预处理数据集 @@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo 对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml +llamafactory-cli train examples/train_lora/llama3_preprocess.yaml ``` #### 在 MMLU/CMMLU/C-Eval 上评估 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml +llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml ``` #### 批量预测并计算 BLEU 和 ROUGE 分数 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml +llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml ``` -### 单 GPU QLoRA 微调 - -#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐) +#### 多机指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml -``` - -#### 基于 4/8 比特 GPTQ 量化进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml -``` - -#### 基于 4 比特 AWQ 量化进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml -``` - -#### 基于 2 比特 AQLM 量化进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml -``` - -### 多 GPU LoRA 微调 - -#### 在单机上进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml -``` - -#### 在多机上进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ``` #### 使用 DeepSpeed ZeRO-3 平均分配显存 ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml +FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` -### 多 NPU LoRA 微调 +### QLoRA 微调 -#### 使用 DeepSpeed ZeRO-0 进行指令监督微调 +#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐) ```bash -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml +llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml ``` -### 多 GPU 全参数微调 +#### 基于 4/8 比特 GPTQ 量化进行指令监督微调 + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml +``` + +#### 基于 4 比特 AWQ 量化进行指令监督微调 + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml +``` + +#### 基于 2 比特 AQLM 量化进行指令监督微调 + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml +``` + +### 全参数微调 #### 在单机上进行指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml +FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml ``` #### 在多机上进行指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml -CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml ``` #### 批量预测并计算 BLEU 和 ROUGE 分数 ```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml +llamafactory-cli train examples/train_full/llama3_full_predict.yaml ``` ### 合并 LoRA 适配器与模型量化 @@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam 注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml ``` #### 使用 AutoGPTQ 量化模型 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml +llamafactory-cli export examples/merge_lora/llama3_gptq.yaml ``` ### 推理 LoRA 模型 -使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。 - #### 使用命令行接口 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml ``` #### 使用浏览器界面 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml +llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml ``` #### 启动 OpenAI 风格 API ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml +llamafactory-cli api examples/inference/llama3_lora_sft.yaml ``` ### 杂项 @@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y #### 使用 GaLore 进行全参数训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml +llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml ``` #### 使用 BAdam 进行全参数训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml +llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml ``` #### LoRA+ 微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml +llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml +``` + +#### PiSSA 微调 + +```bash +llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml ``` #### 深度混合微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml +llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml ``` #### LLaMA-Pro 微调 ```bash bash examples/extras/llama_pro/expand.sh -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml +llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml ``` #### FSDP+QLoRA 微调 ```bash -bash examples/extras/fsdp_qlora/single_node.sh +bash examples/extras/fsdp_qlora/train.sh ``` diff --git a/examples/full_multi_gpu/llama3_full_sft.yaml b/examples/extras/badam/llama3_full_sft.yaml similarity index 81% rename from examples/full_multi_gpu/llama3_full_sft.yaml rename to examples/extras/badam/llama3_full_sft.yaml index 40b62f24..31d61c33 100644 --- a/examples/full_multi_gpu/llama3_full_sft.yaml +++ b/examples/extras/badam/llama3_full_sft.yaml @@ -5,10 +5,11 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: full - -### ddp -ddp_timeout: 180000000 -deepspeed: examples/deepspeed/ds_z3_config.json +use_badam: true +badam_mode: layer +badam_switch_mode: ascending +badam_switch_interval: 50 +badam_verbose: 2 ### dataset dataset: identity,alpaca_en_demo @@ -27,12 +28,11 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 -gradient_accumulation_steps: 2 +gradient_accumulation_steps: 8 learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true ### eval val_size: 0.1 diff --git a/examples/extras/badam/llama3_lora_sft.yaml b/examples/extras/badam/llama3_full_sft_ds3.yaml similarity index 91% rename from examples/extras/badam/llama3_lora_sft.yaml rename to examples/extras/badam/llama3_full_sft_ds3.yaml index a78de2fa..f2d7309f 100644 --- a/examples/extras/badam/llama3_lora_sft.yaml +++ b/examples/extras/badam/llama3_full_sft_ds3.yaml @@ -6,9 +6,11 @@ stage: sft do_train: true finetuning_type: full use_badam: true +badam_mode: layer badam_switch_mode: ascending badam_switch_interval: 50 badam_verbose: 2 +deepspeed: examples/deepspeed/ds_z3_config.json ### dataset dataset: identity,alpaca_en_demo @@ -32,7 +34,6 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -pure_bf16: true ### eval val_size: 0.1 diff --git a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml index 084269ef..6c80ef58 100644 --- a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml +++ b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml @@ -8,9 +8,6 @@ do_train: true finetuning_type: lora lora_target: all -### ddp -ddp_timeout: 180000000 - ### dataset dataset: identity,alpaca_en_demo template: llama3 @@ -33,7 +30,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/extras/fsdp_qlora/single_node.sh b/examples/extras/fsdp_qlora/train.sh similarity index 100% rename from examples/extras/fsdp_qlora/single_node.sh rename to examples/extras/fsdp_qlora/train.sh diff --git a/examples/extras/llama_pro/llama3_freeze_sft.yaml b/examples/extras/llama_pro/llama3_freeze_sft.yaml index 444a1113..5e7e90bb 100644 --- a/examples/extras/llama_pro/llama3_freeze_sft.yaml +++ b/examples/extras/llama_pro/llama3_freeze_sft.yaml @@ -31,7 +31,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/extras/loraplus/llama3_lora_sft.yaml b/examples/extras/loraplus/llama3_lora_sft.yaml index 1ba654ec..062a312b 100644 --- a/examples/extras/loraplus/llama3_lora_sft.yaml +++ b/examples/extras/loraplus/llama3_lora_sft.yaml @@ -30,7 +30,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/extras/mod/llama3_full_sft.yaml b/examples/extras/mod/llama3_full_sft.yaml index df03c1e0..085febfc 100644 --- a/examples/extras/mod/llama3_full_sft.yaml +++ b/examples/extras/mod/llama3_full_sft.yaml @@ -31,6 +31,7 @@ num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 pure_bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml b/examples/extras/pissa/llama3_lora_sft.yaml similarity index 88% rename from examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml rename to examples/extras/pissa/llama3_lora_sft.yaml index b308dcab..05077b6c 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml +++ b/examples/extras/pissa/llama3_lora_sft.yaml @@ -1,12 +1,14 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct -quantization_bit: 4 ### method stage: sft do_train: true finetuning_type: lora lora_target: all +pissa_init: true +pissa_iter: 4 +pissa_convert: true ### dataset dataset: identity,alpaca_en_demo @@ -30,7 +32,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/full_multi_gpu/llama3_full_predict.yaml b/examples/train_full/llama3_full_predict.yaml similarity index 100% rename from examples/full_multi_gpu/llama3_full_predict.yaml rename to examples/train_full/llama3_full_predict.yaml diff --git a/examples/lora_multi_gpu/llama3_lora_sft.yaml b/examples/train_full/llama3_full_sft_ds3.yaml similarity index 83% rename from examples/lora_multi_gpu/llama3_lora_sft.yaml rename to examples/train_full/llama3_full_sft_ds3.yaml index 348e53b9..c983ad5c 100644 --- a/examples/lora_multi_gpu/llama3_lora_sft.yaml +++ b/examples/train_full/llama3_full_sft_ds3.yaml @@ -4,11 +4,8 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct ### method stage: sft do_train: true -finetuning_type: lora -lora_target: all - -### ddp -ddp_timeout: 180000000 +finetuning_type: full +deepspeed: examples/deepspeed/ds_z3_config.json ### dataset dataset: identity,alpaca_en_demo @@ -19,7 +16,7 @@ overwrite_cache: true preprocessing_num_workers: 16 ### output -output_dir: saves/llama3-8b/lora/sft +output_dir: saves/llama3-8b/full/sft logging_steps: 10 save_steps: 500 plot_loss: true @@ -32,7 +29,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_single_gpu/llama3_lora_dpo.yaml b/examples/train_lora/llama3_lora_dpo.yaml similarity index 87% rename from examples/lora_single_gpu/llama3_lora_dpo.yaml rename to examples/train_lora/llama3_lora_dpo.yaml index 78344330..d87c0669 100644 --- a/examples/lora_single_gpu/llama3_lora_dpo.yaml +++ b/examples/train_lora/llama3_lora_dpo.yaml @@ -7,7 +7,7 @@ do_train: true finetuning_type: lora lora_target: all pref_beta: 0.1 -pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo] +pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo] ### dataset dataset: dpo_en_demo @@ -31,7 +31,8 @@ learning_rate: 5.0e-6 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_single_gpu/llama3_lora_eval.yaml b/examples/train_lora/llama3_lora_eval.yaml similarity index 100% rename from examples/lora_single_gpu/llama3_lora_eval.yaml rename to examples/train_lora/llama3_lora_eval.yaml diff --git a/examples/lora_single_gpu/llama3_lora_kto.yaml b/examples/train_lora/llama3_lora_kto.yaml similarity index 93% rename from examples/lora_single_gpu/llama3_lora_kto.yaml rename to examples/train_lora/llama3_lora_kto.yaml index d5234c0a..08208c25 100644 --- a/examples/lora_single_gpu/llama3_lora_kto.yaml +++ b/examples/train_lora/llama3_lora_kto.yaml @@ -6,6 +6,7 @@ stage: kto do_train: true finetuning_type: lora lora_target: all +pref_beta: 0.1 ### dataset dataset: kto_en_demo @@ -29,7 +30,8 @@ learning_rate: 5.0e-6 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_single_gpu/llama3_lora_ppo.yaml b/examples/train_lora/llama3_lora_ppo.yaml similarity index 95% rename from examples/lora_single_gpu/llama3_lora_ppo.yaml rename to examples/train_lora/llama3_lora_ppo.yaml index 98c842f9..512e90ea 100644 --- a/examples/lora_single_gpu/llama3_lora_ppo.yaml +++ b/examples/train_lora/llama3_lora_ppo.yaml @@ -30,7 +30,8 @@ learning_rate: 1.0e-5 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### generate max_new_tokens: 512 diff --git a/examples/lora_single_gpu/llama3_lora_predict.yaml b/examples/train_lora/llama3_lora_predict.yaml similarity index 95% rename from examples/lora_single_gpu/llama3_lora_predict.yaml rename to examples/train_lora/llama3_lora_predict.yaml index a127d248..148c8635 100644 --- a/examples/lora_single_gpu/llama3_lora_predict.yaml +++ b/examples/train_lora/llama3_lora_predict.yaml @@ -22,3 +22,4 @@ overwrite_output_dir: true ### eval per_device_eval_batch_size: 1 predict_with_generate: true +ddp_timeout: 180000000 diff --git a/examples/lora_single_gpu/llama3_lora_pretrain.yaml b/examples/train_lora/llama3_lora_pretrain.yaml similarity index 94% rename from examples/lora_single_gpu/llama3_lora_pretrain.yaml rename to examples/train_lora/llama3_lora_pretrain.yaml index db435ca9..5e8aaaef 100644 --- a/examples/lora_single_gpu/llama3_lora_pretrain.yaml +++ b/examples/train_lora/llama3_lora_pretrain.yaml @@ -28,7 +28,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_single_gpu/llama3_lora_reward.yaml b/examples/train_lora/llama3_lora_reward.yaml similarity index 91% rename from examples/lora_single_gpu/llama3_lora_reward.yaml rename to examples/train_lora/llama3_lora_reward.yaml index 1ce42ea4..96c32238 100644 --- a/examples/lora_single_gpu/llama3_lora_reward.yaml +++ b/examples/train_lora/llama3_lora_reward.yaml @@ -25,11 +25,12 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 1.0e-5 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_single_gpu/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml similarity index 95% rename from examples/lora_single_gpu/llama3_lora_sft.yaml rename to examples/train_lora/llama3_lora_sft.yaml index 651b636f..55a8077e 100644 --- a/examples/lora_single_gpu/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -29,7 +29,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds0.yaml similarity index 97% rename from examples/lora_multi_npu/llama3_lora_sft_ds.yaml rename to examples/train_lora/llama3_lora_sft_ds0.yaml index a0ec8aa1..f1442faa 100644 --- a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml +++ b/examples/train_lora/llama3_lora_sft_ds0.yaml @@ -6,9 +6,6 @@ stage: sft do_train: true finetuning_type: lora lora_target: all - -### ddp -ddp_timeout: 180000000 deepspeed: examples/deepspeed/ds_z0_config.json ### dataset @@ -33,7 +30,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds3.yaml similarity index 97% rename from examples/lora_multi_gpu/llama3_lora_sft_ds.yaml rename to examples/train_lora/llama3_lora_sft_ds3.yaml index 1c432fa7..66e7007e 100644 --- a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml +++ b/examples/train_lora/llama3_lora_sft_ds3.yaml @@ -6,9 +6,6 @@ stage: sft do_train: true finetuning_type: lora lora_target: all - -### ddp -ddp_timeout: 180000000 deepspeed: examples/deepspeed/ds_z3_config.json ### dataset @@ -33,7 +30,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/lora_single_gpu/llama3_preprocess.yaml b/examples/train_lora/llama3_preprocess.yaml similarity index 100% rename from examples/lora_single_gpu/llama3_preprocess.yaml rename to examples/train_lora/llama3_preprocess.yaml diff --git a/examples/lora_single_gpu/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml similarity index 95% rename from examples/lora_single_gpu/llava1_5_lora_sft.yaml rename to examples/train_lora/llava1_5_lora_sft.yaml index df510a93..ec03f82c 100644 --- a/examples/lora_single_gpu/llava1_5_lora_sft.yaml +++ b/examples/train_lora/llava1_5_lora_sft.yaml @@ -30,7 +30,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml b/examples/train_qlora/llama3_lora_sft_aqlm.yaml similarity index 95% rename from examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml rename to examples/train_qlora/llama3_lora_sft_aqlm.yaml index d54d6af6..3519d46b 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml +++ b/examples/train_qlora/llama3_lora_sft_aqlm.yaml @@ -29,7 +29,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml b/examples/train_qlora/llama3_lora_sft_awq.yaml similarity index 95% rename from examples/qlora_single_gpu/llama3_lora_sft_awq.yaml rename to examples/train_qlora/llama3_lora_sft_awq.yaml index 5cef178a..df48669b 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml +++ b/examples/train_qlora/llama3_lora_sft_awq.yaml @@ -29,7 +29,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml b/examples/train_qlora/llama3_lora_sft_gptq.yaml similarity index 95% rename from examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml rename to examples/train_qlora/llama3_lora_sft_gptq.yaml index b950042e..61fa9bb4 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml +++ b/examples/train_qlora/llama3_lora_sft_gptq.yaml @@ -29,7 +29,8 @@ learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine warmup_ratio: 0.1 -fp16: true +bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 diff --git a/examples/train_qlora/llama3_lora_sft_otfq.yaml b/examples/train_qlora/llama3_lora_sft_otfq.yaml new file mode 100644 index 00000000..80a05768 --- /dev/null +++ b/examples/train_qlora/llama3_lora_sft_otfq.yaml @@ -0,0 +1,41 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +quantization_bit: 4 +quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)] + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_target: all + +### dataset +dataset: identity,alpaca_en_demo +template: llama3 +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: saves/llama3-8b/lora/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 diff --git a/requirements.txt b/requirements.txt index 9e00555e..7380add4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ accelerate>=0.30.1 peft>=0.11.1 trl>=0.8.6 gradio>=4.0.0 +pandas>=2.0.0 scipy einops sentencepiece @@ -17,3 +18,4 @@ matplotlib>=3.7.0 fire packaging pyyaml +numpy<2.0.0 diff --git a/scripts/cal_flops.py b/scripts/cal_flops.py index ac87e0ab..32526d89 100644 --- a/scripts/cal_flops.py +++ b/scripts/cal_flops.py @@ -1,7 +1,20 @@ # coding=utf-8 -# Calculates the flops of pre-trained models. -# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 -# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/ +# Copyright 2024 Microsoft Corporation and the LlamaFactory team. +# +# This code is inspired by the Microsoft's DeepSpeed library. +# https://www.deepspeed.ai/tutorials/flops-profiler/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import fire import torch @@ -17,6 +30,10 @@ def calculate_flops( seq_length: int = 256, flash_attn: str = "auto", ): + r""" + Calculates the flops of pre-trained models. + Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 + """ with get_accelerator().device(0): chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device) diff --git a/scripts/cal_lr.py b/scripts/cal_lr.py index bfa32cc9..ad6992cb 100644 --- a/scripts/cal_lr.py +++ b/scripts/cal_lr.py @@ -1,7 +1,20 @@ # coding=utf-8 -# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. -# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 -# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py +# Copyright 2024 imoneoi and the LlamaFactory team. +# +# This code is inspired by the imoneoi's OpenChat library. +# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math from typing import Literal @@ -32,6 +45,10 @@ def calculate_lr( cutoff_len: int = 1024, # i.e. maximum input length during training is_mistral: bool = False, # mistral model uses a smaller learning rate, ): + r""" + Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. + Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 + """ model_args, data_args, training_args, _, _ = get_train_args( dict( stage=stage, diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py index 387b756c..fb503629 100644 --- a/scripts/cal_ppl.py +++ b/scripts/cal_ppl.py @@ -1,6 +1,17 @@ # coding=utf-8 -# Calculates the ppl on the dataset of the pre-trained models. -# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json from dataclasses import dataclass @@ -56,6 +67,10 @@ def cal_ppl( max_samples: Optional[int] = None, train_on_prompt: bool = False, ): + r""" + Calculates the ppl on the dataset of the pre-trained models. + Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json + """ model_args, data_args, training_args, finetuning_args, _ = get_train_args( dict( stage=stage, diff --git a/scripts/length_cdf.py b/scripts/length_cdf.py index 7739dcf0..4cdf01e6 100644 --- a/scripts/length_cdf.py +++ b/scripts/length_cdf.py @@ -1,6 +1,17 @@ # coding=utf-8 -# Calculates the distribution of the input lengths in the dataset. -# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections import defaultdict @@ -19,6 +30,10 @@ def length_cdf( template: str = "default", interval: int = 1000, ): + r""" + Calculates the distribution of the input lengths in the dataset. + Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default + """ model_args, data_args, training_args, _, _ = get_train_args( dict( stage="sft", diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py index 727998ae..17bf6fc2 100644 --- a/scripts/llama_pro.py +++ b/scripts/llama_pro.py @@ -1,7 +1,20 @@ # coding=utf-8 -# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models. -# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 -# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py +# Copyright 2024 Tencent Inc. and the LlamaFactory team. +# +# This code is inspired by the Tencent's LLaMA-Pro library. +# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os @@ -37,6 +50,10 @@ def block_expansion( shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False, ): + r""" + Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models. + Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 + """ config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) num_layers = getattr(config, "num_hidden_layers") setattr(config, "num_hidden_layers", num_layers + num_expand) @@ -103,7 +120,7 @@ def block_expansion( json.dump(index, f, indent=2, sort_keys=True) print("Model weights saved in {}".format(output_dir)) - print("Fine-tune this model with:") + print("- Fine-tune this model with:") print("model_name_or_path: {}".format(output_dir)) print("finetuning_type: freeze") print("freeze_trainable_layers: {}".format(num_expand)) diff --git a/scripts/llamafy_baichuan2.py b/scripts/llamafy_baichuan2.py index 1ae58879..19284f5f 100644 --- a/scripts/llamafy_baichuan2.py +++ b/scripts/llamafy_baichuan2.py @@ -1,8 +1,17 @@ # coding=utf-8 -# Converts the Baichuan2-7B model in the same format as LLaMA2-7B. -# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output -# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py -# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os @@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str): def llamafy_baichuan2( input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False ): + r""" + Converts the Baichuan2-7B model in the same format as LLaMA2-7B. + Usage: python llamafy_baichuan2.py --input_dir input --output_dir output + Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied + """ try: os.makedirs(output_dir, exist_ok=False) except Exception as e: diff --git a/scripts/llamafy_qwen.py b/scripts/llamafy_qwen.py index 69cf3e8e..e5b59483 100644 --- a/scripts/llamafy_qwen.py +++ b/scripts/llamafy_qwen.py @@ -1,7 +1,17 @@ # coding=utf-8 -# Converts the Qwen models in the same format as LLaMA2. -# Usage: python llamafy_qwen.py --input_dir input --output_dir output -# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os @@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str): def llamafy_qwen( input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False ): + r""" + Converts the Qwen models in the same format as LLaMA2. + Usage: python llamafy_qwen.py --input_dir input --output_dir output + Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied + """ try: os.makedirs(output_dir, exist_ok=False) except Exception as e: diff --git a/scripts/loftq_init.py b/scripts/loftq_init.py index 7f244316..4d2c01b9 100644 --- a/scripts/loftq_init.py +++ b/scripts/loftq_init.py @@ -1,14 +1,25 @@ # coding=utf-8 -# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) -# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir -# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is based on the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import fire -import torch -import torch.nn as nn from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer @@ -17,65 +28,61 @@ if TYPE_CHECKING: from transformers import PreTrainedModel -class Shell(nn.Module): - def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - super().__init__() - self.weight = nn.Parameter(weight, requires_grad=False) - if bias is not None: - self.bias = nn.Parameter(bias, requires_grad=False) - - -def unwrap_model(model: nn.Module, pattern=".base_layer") -> None: - for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}: - parent_name = ".".join(name.split(".")[:-1]) - child_name = name.split(".")[-1] - parent_module = model.get_submodule(parent_name) - child_module = getattr(parent_module, child_name) - base_layer = getattr(child_module, "base_layer") - weight = getattr(base_layer, "weight", None) - bias = getattr(base_layer, "bias", None) - setattr(parent_module, child_name, Shell(weight, bias)) - - print("Model unwrapped.") - - def quantize_loftq( model_name_or_path: str, - save_dir: str, - loftq_bits: Optional[int] = 4, - loftq_iter: Optional[int] = 1, - lora_alpha: Optional[int] = None, - lora_rank: Optional[int] = 16, - lora_target: Optional[str] = "q_proj,v_proj", - save_safetensors: Optional[bool] = False, + output_dir: str, + loftq_bits: int = 4, + loftq_iter: int = 4, + lora_alpha: int = None, + lora_rank: int = 16, + lora_dropout: float = 0, + lora_target: tuple = ("q_proj", "v_proj"), + save_safetensors: bool = True, ): + r""" + Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) + Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir + """ + if isinstance(lora_target, str): + lora_target = [name.strip() for name in lora_target.split(",")] + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") + loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter) lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=True, r=lora_rank, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, - lora_dropout=0.1, - target_modules=[name.strip() for name in lora_target.split(",")], + lora_dropout=lora_dropout, + target_modules=lora_target, init_lora_weights="loftq", loftq_config=loftq_config, ) # Init LoftQ model - lora_model = get_peft_model(model, lora_config) - base_model: "PreTrainedModel" = lora_model.get_base_model() + print("Initializing LoftQ weights, it may be take several minutes, wait patiently.") + peft_model = get_peft_model(model, lora_config) + loftq_dir = os.path.join(output_dir, "loftq_init") # Save LoftQ model - setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir) - setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True) - lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors) + setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir) + setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again + peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors) + print("Adapter weights saved in {}".format(loftq_dir)) # Save base model - unwrap_model(base_model) - base_model.save_pretrained(save_dir, safe_serialization=save_safetensors) - tokenizer.save_pretrained(save_dir) + base_model: "PreTrainedModel" = peft_model.unload() + base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) + tokenizer.save_pretrained(output_dir) + print("Model weights saved in {}".format(output_dir)) + + print("- Fine-tune this model with:") + print("model_name_or_path: {}".format(output_dir)) + print("adapter_name_or_path: {}".format(loftq_dir)) + print("finetuning_type: lora") + print("quantization_bit: {}".format(loftq_bits)) if __name__ == "__main__": diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py new file mode 100644 index 00000000..ad9d161c --- /dev/null +++ b/scripts/pissa_init.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is based on the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING + +import fire +from peft import LoraConfig, TaskType, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +def quantize_pissa( + model_name_or_path: str, + output_dir: str, + pissa_iter: int = 4, + lora_alpha: int = None, + lora_rank: int = 16, + lora_dropout: float = 0, + lora_target: tuple = ("q_proj", "v_proj"), + save_safetensors: bool = True, +): + r""" + Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA) + Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir + """ + if isinstance(lora_target, str): + lora_target = [name.strip() for name in lora_target.split(",")] + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=lora_rank, + lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, + lora_dropout=lora_dropout, + target_modules=lora_target, + init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter), + ) + + # Init PiSSA model + peft_model = get_peft_model(model, lora_config) + pissa_dir = os.path.join(output_dir, "pissa_init") + + # Save PiSSA model + setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again + peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors) + print("Adapter weights saved in {}".format(pissa_dir)) + + # Save base model + base_model: "PreTrainedModel" = peft_model.unload() + base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) + tokenizer.save_pretrained(output_dir) + print("Model weights saved in {}".format(output_dir)) + + print("- Fine-tune this model with:") + print("model_name_or_path: {}".format(output_dir)) + print("adapter_name_or_path: {}".format(pissa_dir)) + print("finetuning_type: lora") + print("pissa_init: false") + print("pissa_convert: true") + print("- and optionally with:") + print("quantization_bit: 4") + + +if __name__ == "__main__": + fire.Fire(quantize_pissa) diff --git a/scripts/test_toolcall.py b/scripts/test_toolcall.py index 7e460017..6f6fd06c 100644 --- a/scripts/test_toolcall.py +++ b/scripts/test_toolcall.py @@ -1,3 +1,18 @@ +# coding=utf-8 +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from typing import Sequence diff --git a/setup.py b/setup.py index 405ac46e..d43c311c 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import re @@ -23,14 +37,16 @@ extra_require = { "torch": ["torch>=1.13.1"], "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"], "metrics": ["nltk", "jieba", "rouge-chinese"], - "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"], + "deepspeed": ["deepspeed>=0.10.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"], - "vllm": ["vllm>=0.4.3"], - "galore": ["galore-torch"], - "badam": ["badam"], - "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"], + "hqq": ["hqq"], + "eetq": ["eetq"], + "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "awq": ["autoawq"], "aqlm": ["aqlm[gpu]>=1.1.0"], + "vllm": ["vllm>=0.4.3"], + "galore": ["galore-torch"], + "badam": ["badam>=1.2.1"], "qwen": ["transformers_stream_generator"], "modelscope": ["modelscope"], "dev": ["ruff", "pytest"], diff --git a/src/api.py b/src/api.py index 3655e393..0f925497 100644 --- a/src/api.py +++ b/src/api.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import uvicorn diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index 78230937..9d732777 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Level: api, webui > chat, eval, train > data, model > hparams > extras from .cli import VERSION diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py index 21edab2f..c1264617 100644 --- a/src/llamafactory/api/app.py +++ b/src/llamafactory/api/app.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from contextlib import asynccontextmanager from typing import Optional diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index 98957bc1..72b2ae50 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import base64 import io import json @@ -78,9 +92,11 @@ def _process_request( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls): - name = message.tool_calls[0].function.name - arguments = message.tool_calls[0].function.arguments - content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False) + tool_calls = [ + {"name": tool_call.function.name, "arguments": tool_call.function.arguments} + for tool_call in message.tool_calls + ] + content = json.dumps(tool_calls, ensure_ascii=False) input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) elif isinstance(message.content, list): for input_item in message.content: @@ -104,7 +120,7 @@ def _process_request( if isinstance(tool_list, list) and len(tool_list): try: tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False) - except Exception: + except json.JSONDecodeError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") else: tools = None @@ -146,15 +162,17 @@ async def create_chat_completion_response( choices = [] for i, response in enumerate(responses): if tools: - result = chat_model.engine.template.format_tools.extract(response.response_text) + result = chat_model.engine.template.extract_tool(response.response_text) else: result = response.response_text - if isinstance(result, tuple): - name, arguments = result - function = Function(name=name, arguments=arguments) - tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function) - response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call]) + if isinstance(result, list): + tool_calls = [] + for tool in result: + function = Function(name=tool[0], arguments=tool[1]) + tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)) + + response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) finish_reason = Finish.TOOL else: response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result) diff --git a/src/llamafactory/api/common.py b/src/llamafactory/api/common.py index 5ad9a071..d1ac94de 100644 --- a/src/llamafactory/api/common.py +++ b/src/llamafactory/api/common.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from typing import TYPE_CHECKING, Any, Dict diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py index 055fa781..a69132ea 100644 --- a/src/llamafactory/api/protocol.py +++ b/src/llamafactory/api/protocol.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from enum import Enum, unique from typing import Any, Dict, List, Optional, Union diff --git a/src/llamafactory/chat/__init__.py b/src/llamafactory/chat/__init__.py index a1a79de6..07276d48 100644 --- a/src/llamafactory/chat/__init__.py +++ b/src/llamafactory/chat/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .base_engine import BaseEngine from .chat_model import ChatModel diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py index 65b6c59c..ccdf4c92 100644 --- a/src/llamafactory/chat/base_engine.py +++ b/src/llamafactory/chat/base_engine.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union @@ -36,11 +50,6 @@ class BaseEngine(ABC): generating_args: "GeneratingArguments", ) -> None: ... - @abstractmethod - async def start( - self, - ) -> None: ... - @abstractmethod async def chat( self, diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index 281ef0c1..5c83fa67 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -1,3 +1,20 @@ +# Copyright 2024 THUDM and the LlamaFactory team. +# +# This code is inspired by the THUDM's ChatGLM implementation. +# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio from threading import Thread from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence @@ -14,7 +31,7 @@ if TYPE_CHECKING: from .base_engine import BaseEngine, Response -def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None: +def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None: asyncio.set_event_loop(loop) loop.run_forever() @@ -32,7 +49,6 @@ class ChatModel: self._loop = asyncio.new_event_loop() self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) self._thread.start() - asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop) def chat( self, diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 28e6a409..22a24339 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio import concurrent.futures import os @@ -40,11 +54,19 @@ class HuggingfaceEngine(BaseEngine): self.tokenizer = tokenizer_module["tokenizer"] self.processor = tokenizer_module["processor"] self.tokenizer.padding_side = "left" if self.can_generate else "right" - self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format) self.model = load_model( self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) ) # must after fixing tokenizer to resize vocab self.generating_args = generating_args.to_dict() + try: + asyncio.get_event_loop() + except RuntimeError: + logger.warning("There is no current event loop, creating a new one.") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1"))) @staticmethod def _process_args( @@ -245,9 +267,6 @@ class HuggingfaceEngine(BaseEngine): return scores - async def start(self) -> None: - self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1))) - async def chat( self, messages: Sequence[Dict[str, str]], @@ -272,7 +291,7 @@ class HuggingfaceEngine(BaseEngine): image, input_kwargs, ) - async with self._semaphore: + async with self.semaphore: with concurrent.futures.ThreadPoolExecutor() as pool: return await loop.run_in_executor(pool, self._chat, *input_args) @@ -300,7 +319,7 @@ class HuggingfaceEngine(BaseEngine): image, input_kwargs, ) - async with self._semaphore: + async with self.semaphore: with concurrent.futures.ThreadPoolExecutor() as pool: stream = self._stream_chat(*input_args) while True: @@ -319,6 +338,6 @@ class HuggingfaceEngine(BaseEngine): loop = asyncio.get_running_loop() input_args = (self.model, self.tokenizer, batch_input, input_kwargs) - async with self._semaphore: + async with self.semaphore: with concurrent.futures.ThreadPoolExecutor() as pool: return await loop.run_in_executor(pool, self._get_scores, *input_args) diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 87ce8684..f0d23676 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -1,10 +1,24 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import uuid from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from ..data import get_template_and_fix_tokenizer from ..extras.logging import get_logger from ..extras.misc import get_device_count -from ..extras.packages import is_vllm_available +from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5 from ..model import load_config, load_tokenizer from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM from .base_engine import BaseEngine, Response @@ -13,7 +27,11 @@ from .base_engine import BaseEngine, Response if is_vllm_available(): from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams from vllm.lora.request import LoRARequest - from vllm.sequence import MultiModalData + + if is_vllm_version_greater_than_0_5(): + from vllm.multimodal.image import ImagePixelData + else: + from vllm.sequence import MultiModalData if TYPE_CHECKING: @@ -41,14 +59,14 @@ class VllmEngine(BaseEngine): self.tokenizer = tokenizer_module["tokenizer"] self.processor = tokenizer_module["processor"] self.tokenizer.padding_side = "left" - self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format) self.generating_args = generating_args.to_dict() engine_args = { "model": model_args.model_name_or_path, "trust_remote_code": True, "download_dir": model_args.cache_dir, - "dtype": model_args.vllm_dtype, + "dtype": model_args.infer_dtype, "max_model_len": model_args.vllm_maxlen, "tensor_parallel_size": get_device_count() or 1, "gpu_memory_utilization": model_args.vllm_gpu_util, @@ -106,7 +124,10 @@ class VllmEngine(BaseEngine): if self.processor is not None and image is not None: # add image features image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") pixel_values = image_processor(image, return_tensors="pt")["pixel_values"] - multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) + if is_vllm_version_greater_than_0_5(): + multi_modal_data = ImagePixelData(image=pixel_values) + else: # TODO: remove vllm 0.4.3 support + multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) else: multi_modal_data = None @@ -162,9 +183,6 @@ class VllmEngine(BaseEngine): ) return result_generator - async def start(self) -> None: - pass - async def chat( self, messages: Sequence[Dict[str, str]], diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 5042e53c..48eb2898 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import random import subprocess @@ -60,7 +74,7 @@ class Command(str, Enum): def main(): - command = sys.argv.pop(1) + command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP if command == Command.API: run_api() elif command == Command.CHAT: @@ -77,7 +91,7 @@ def main(): master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) - subprocess.run( + process = subprocess.run( ( "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" @@ -92,6 +106,7 @@ def main(): ), shell=True, ) + sys.exit(process.returncode) else: run_exp() elif command == Command.WEBDEMO: diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index b08691d3..307853bc 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding from .data_utils import Role, split_dataset from .loader import get_dataset diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index 434956af..299bdca3 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from functools import partial from typing import TYPE_CHECKING, Any, Dict, List, Union @@ -10,6 +24,7 @@ from .data_utils import Role if TYPE_CHECKING: from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments from ..hparams import DataArguments from .parser import DatasetAttr @@ -175,7 +190,10 @@ def convert_sharegpt( def align_dataset( - dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments" + dataset: Union["Dataset", "IterableDataset"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", ) -> Union["Dataset", "IterableDataset"]: r""" Aligned dataset: @@ -208,7 +226,7 @@ def align_dataset( if not data_args.streaming: kwargs = dict( num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not data_args.overwrite_cache), + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), desc="Converting format of dataset", ) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 1dc8dd8d..e4859ff5 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Any, Dict, Sequence diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index 9b313112..76ded47e 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -1,5 +1,19 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from enum import Enum, unique -from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union from datasets import concatenate_datasets, interleave_datasets @@ -16,6 +30,9 @@ if TYPE_CHECKING: logger = get_logger(__name__) +SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] + + @unique class Role(str, Enum): USER = "user" @@ -25,13 +42,6 @@ class Role(str, Enum): OBSERVATION = "observation" -def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]: - max_target_len = int(max_len * (target_len / (source_len + target_len))) - max_target_len = max(max_target_len, reserved_label_len) - max_source_len = max_len - min(max_target_len, target_len) - return max_source_len, max_target_len - - def merge_dataset( all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 0cd3d6c1..c1653a76 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -1,83 +1,36 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union - -SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] - - -JSON_FORMAT_PROMPT = ( - """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)""" -) - - -TOOL_SYSTEM_PROMPT = ( - "You have access to the following tools:\n{tool_text}" - "Use the following format if using a tool:\n" - "```\n" - "Action: tool name (one of [{tool_names}]).\n" - "Action Input: the input to the tool{format_prompt}.\n" - "```\n" -) - - -def default_tool_formatter(tools: List[Dict[str, Any]]) -> str: - tool_text = "" - tool_names = [] - for tool in tools: - param_text = "" - for name, param in tool["parameters"]["properties"].items(): - required = ", required" if name in tool["parameters"].get("required", []) else "" - enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else "" - items = ( - ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else "" - ) - param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format( - name=name, - type=param.get("type", ""), - required=required, - desc=param.get("description", ""), - enum=enum, - items=items, - ) - - tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( - name=tool["name"], desc=tool.get("description", ""), args=param_text - ) - tool_names.append(tool["name"]) - - return TOOL_SYSTEM_PROMPT.format( - tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT - ) - - -def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: - regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL) - action_match = re.search(regex, content) - if not action_match: - return content - - tool_name = action_match.group(1).strip() - tool_input = action_match.group(2).strip().strip('"').strip("```") - try: - arguments = json.loads(tool_input) - except json.JSONDecodeError: - return content - - return tool_name, json.dumps(arguments, ensure_ascii=False) +from .data_utils import SLOTS +from .tool_utils import DefaultToolUtils, GLM4ToolUtils @dataclass class Formatter(ABC): slots: SLOTS = field(default_factory=list) - tool_format: Optional[Literal["default"]] = None + tool_format: Optional[Literal["default", "glm4"]] = None @abstractmethod def apply(self, **kwargs) -> SLOTS: ... - def extract(self, content: str) -> Union[str, Tuple[str, str]]: + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: raise NotImplementedError @@ -128,34 +81,37 @@ class StringFormatter(Formatter): @dataclass class FunctionFormatter(Formatter): def __post_init__(self): - has_name, has_args = False, False - for slot in filter(lambda s: isinstance(s, str), self.slots): - if "{{name}}" in slot: - has_name = True - if "{{arguments}}" in slot: - has_args = True - - if not has_name or not has_args: - raise ValueError("Name and arguments placeholders are required in the function formatter.") + if self.tool_format == "default": + self.slots = DefaultToolUtils.get_function_slots() + self.slots + elif self.tool_format == "glm4": + self.slots = GLM4ToolUtils.get_function_slots() + self.slots + else: + raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") + functions: List[Tuple[str, str]] = [] try: - function = json.loads(content) - name = function["name"] - arguments = json.dumps(function["arguments"], ensure_ascii=False) - except Exception: - name, arguments = "", "" + tool_calls = json.loads(content) + if not isinstance(tool_calls, list): # parallel function call + tool_calls = [tool_calls] + + for tool_call in tool_calls: + functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) + + except json.JSONDecodeError: + functions = [] elements = [] - for slot in self.slots: - if isinstance(slot, str): - slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) - elements.append(slot) - elif isinstance(slot, (dict, set)): - elements.append(slot) - else: - raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + for name, arguments in functions: + for slot in self.slots: + if isinstance(slot, str): + slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) return elements @@ -163,25 +119,22 @@ class FunctionFormatter(Formatter): @dataclass class ToolFormatter(Formatter): def __post_init__(self): - if self.tool_format is None: - raise ValueError("Tool format was not found.") + if self.tool_format == "default": + self._tool_formatter = DefaultToolUtils.tool_formatter + self._tool_extractor = DefaultToolUtils.tool_extractor + elif self.tool_format == "glm4": + self._tool_formatter = GLM4ToolUtils.tool_formatter + self._tool_extractor = GLM4ToolUtils.tool_extractor + else: + raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") try: tools = json.loads(content) - if not len(tools): - return [""] - - if self.tool_format == "default": - return [default_tool_formatter(tools)] - else: - raise NotImplementedError - except Exception: + return [self._tool_formatter(tools) if len(tools) != 0 else ""] + except json.JSONDecodeError: return [""] - def extract(self, content: str) -> Union[str, Tuple[str, str]]: - if self.tool_format == "default": - return default_tool_extractor(content) - else: - raise NotImplementedError + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: + return self._tool_extractor(content) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 2c236c76..8e7062db 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect import os import sys @@ -18,8 +32,7 @@ from .template import get_template_and_fix_tokenizer if TYPE_CHECKING: from datasets import Dataset, IterableDataset - from transformers import ProcessorMixin, Seq2SeqTrainingArguments - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments from ..hparams import DataArguments, ModelArguments from .parser import DatasetAttr @@ -32,6 +45,7 @@ def load_single_dataset( dataset_attr: "DatasetAttr", model_args: "ModelArguments", data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", ) -> Union["Dataset", "IterableDataset"]: logger.info("Loading dataset {}...".format(dataset_attr)) data_path, data_name, data_dir, data_files = None, None, None, None @@ -123,7 +137,7 @@ def load_single_dataset( max_samples = min(data_args.max_samples, len(dataset)) dataset = dataset.select(range(max_samples)) - return align_dataset(dataset, dataset_attr, data_args) + return align_dataset(dataset, dataset_attr, data_args, training_args) def get_dataset( @@ -134,7 +148,7 @@ def get_dataset( tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, ) -> Union["Dataset", "IterableDataset"]: - template = get_template_and_fix_tokenizer(tokenizer, data_args.template) + template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") @@ -157,7 +171,8 @@ def get_dataset( if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): raise ValueError("The dataset is not applicable in the current training stage.") - all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) + all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args)) + dataset = merge_dataset(all_datasets, data_args, training_args) with training_args.main_process_first(desc="pre-process dataset"): @@ -169,7 +184,7 @@ def get_dataset( if not data_args.streaming: kwargs = dict( num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not data_args.overwrite_cache), + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), desc="Running tokenizer on dataset", ) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index ec97bfc1..4bebcd68 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from dataclasses import dataclass diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index cf207d7e..3a80900c 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from functools import partial from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple @@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu if TYPE_CHECKING: - from transformers import ProcessorMixin, Seq2SeqTrainingArguments - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments from ..hparams import DataArguments from .template import Template diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index 98d83658..7ba05e23 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -1,13 +1,26 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -42,12 +55,8 @@ def _encode_feedback_example( else: kl_messages = prompt + [kl_response[1]] - prompt_ids, response_ids = template.encode_oneturn( - tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) - _, kl_response_ids = template.encode_oneturn( - tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) + prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) + _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) if template.efficient_eos: response_ids += [tokenizer.eos_token_id] @@ -57,6 +66,12 @@ def _encode_feedback_example( image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + # do not consider the kl_response + source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len) + prompt_ids = prompt_ids[:source_len] + response_ids = response_ids[:target_len] + kl_response_ids = kl_response_ids[:target_len] + input_ids = prompt_ids + response_ids labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids kl_input_ids = prompt_ids + kl_response_ids diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index fe984efa..c6001e6e 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -1,13 +1,26 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -31,12 +44,8 @@ def _encode_pairwise_example( chosen_messages = prompt + [response[0]] rejected_messages = prompt + [response[1]] - prompt_ids, chosen_ids = template.encode_oneturn( - tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) - _, rejected_ids = template.encode_oneturn( - tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) + prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) + _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools) if template.efficient_eos: chosen_ids += [tokenizer.eos_token_id] @@ -46,6 +55,13 @@ def _encode_pairwise_example( image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + source_len, target_len = infer_seqlen( + len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len + ) # consider the response is more important + prompt_ids = prompt_ids[:source_len] + chosen_ids = chosen_ids[:target_len] + rejected_ids = rejected_ids[:target_len] + chosen_input_ids = prompt_ids + chosen_ids chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids rejected_input_ids = prompt_ids + rejected_ids diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py index 87727b55..67d6009b 100644 --- a/src/llamafactory/data/processors/pretrain.py +++ b/src/llamafactory/data/processors/pretrain.py @@ -1,9 +1,26 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from itertools import chain from typing import TYPE_CHECKING, Any, Dict, List if TYPE_CHECKING: - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer from ...hparams import DataArguments @@ -12,7 +29,8 @@ def preprocess_pretrain_dataset( examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" ) -> Dict[str, List[List[int]]]: # build grouped texts with format `X1 X2 X3 ...` if packing is enabled - text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] + eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token + text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]] if not data_args.packing: if data_args.template == "gemma": diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py index 9903a053..455908ae 100644 --- a/src/llamafactory/data/processors/processor_utils.py +++ b/src/llamafactory/data/processors/processor_utils.py @@ -1,5 +1,19 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import bisect -from typing import TYPE_CHECKING, List, Sequence +from typing import TYPE_CHECKING, List, Sequence, Tuple from ...extras.packages import is_pillow_available @@ -62,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> """ image_seq_length = getattr(processor, "image_seq_length") return [0] * image_seq_length + [1] * (input_len - image_seq_length) + + +def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: + if target_len * 2 < cutoff_len: # truncate source + max_target_len = cutoff_len + elif source_len * 2 < cutoff_len: # truncate target + max_target_len = cutoff_len - source_len + else: # truncate both + max_target_len = int(cutoff_len * (target_len / (source_len + target_len))) + + new_target_len = min(max_target_len, target_len) + new_source_len = max(cutoff_len - new_target_len, 0) + return new_source_len, new_target_len diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 35640174..8ef55321 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -1,14 +1,27 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -38,10 +51,17 @@ def _encode_supervised_example( input_ids += [image_token_id] * getattr(processor, "image_seq_length") labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") - encoded_pairs = template.encode_multiturn( - tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) + encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) + total_length = 1 if template.efficient_eos else 0 for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if total_length >= data_args.cutoff_len: + break + + source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length) + source_ids = source_ids[:source_len] + target_ids = target_ids[:target_len] + total_length += source_len + target_len + if data_args.train_on_prompt: source_mask = source_ids elif turn_idx != 0 and template.efficient_eos: diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index f711eeac..b3fc85c9 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -1,13 +1,26 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.logging import get_logger from ..data_utils import Role -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -34,9 +47,7 @@ def _encode_unsupervised_example( else: messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] - input_ids, labels = template.encode_oneturn( - tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) + input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools) if template.efficient_eos: labels += [tokenizer.eos_token_id] @@ -44,6 +55,9 @@ def _encode_unsupervised_example( image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids + source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len) + input_ids = input_ids[:source_len] + labels = labels[:target_len] return input_ids, labels diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index b600c567..aefd5195 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1,8 +1,22 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from ..extras.logging import get_logger -from .data_utils import Role, infer_max_len +from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter @@ -24,69 +38,74 @@ class Template: format_observation: "Formatter" format_tools: "Formatter" format_separator: "Formatter" + format_prefix: "Formatter" default_system: str stop_words: List[str] image_token: str efficient_eos: bool replace_eos: bool - force_system: bool def encode_oneturn( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - cutoff_len: int = 1_000_000, - reserved_label_len: int = 1, ) -> Tuple[List[int], List[int]]: r""" Returns a single pair of token ids representing prompt and response respectively. """ - encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + encoded_messages = self._encode(tokenizer, messages, system, tools) prompt_ids = [] - for query_ids, resp_ids in encoded_pairs[:-1]: - prompt_ids += query_ids + resp_ids - prompt_ids = prompt_ids + encoded_pairs[-1][0] - answer_ids = encoded_pairs[-1][1] + for encoded_ids in encoded_messages[:-1]: + prompt_ids += encoded_ids + + answer_ids = encoded_messages[-1] return prompt_ids, answer_ids def encode_multiturn( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - cutoff_len: int = 1_000_000, - reserved_label_len: int = 1, - ) -> Sequence[Tuple[List[int], List[int]]]: + ) -> List[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. """ - return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + encoded_messages = self._encode(tokenizer, messages, system, tools) + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: + r""" + Extracts tool message. + """ + return self.format_tools.extract(content) def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], - cutoff_len: int, - reserved_label_len: int, - ) -> Sequence[Tuple[List[int], List[int]]]: + ) -> List[List[int]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: system + query resp - Turn t: sep + query resp + Turn 0: prefix + system + query resp + Turn t: sep + query resp """ system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): elements = [] - if i == 0 and (system or tools or self.force_system): - tool_text = self.format_tools.apply(content=tools)[0] if tools else "" - elements += self.format_system.apply(content=(system + tool_text)) - elif i > 0 and i % 2 == 0: + + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) + + if i > 0 and i % 2 == 0: elements += self.format_separator.apply() if message["role"] == Role.USER.value: @@ -102,11 +121,9 @@ class Template: encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) - return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + return encoded_messages - def _convert_elements_to_ids( - self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]] - ) -> List[int]: + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: r""" Converts elements to token ids. """ @@ -127,57 +144,34 @@ class Template: return token_ids - def _make_pairs( - self, - encoded_messages: Sequence[List[int]], - cutoff_len: int, - reserved_label_len: int, - ) -> Sequence[Tuple[List[int], List[int]]]: - encoded_pairs = [] - total_length = 0 - for i in range(0, len(encoded_messages), 2): - if total_length >= cutoff_len: - break - - max_source_len, max_target_len = infer_max_len( - source_len=len(encoded_messages[i]), - target_len=len(encoded_messages[i + 1]), - max_len=(cutoff_len - total_length), - reserved_label_len=reserved_label_len, - ) - source_ids = encoded_messages[i][:max_source_len] - target_ids = encoded_messages[i + 1][:max_target_len] - total_length += len(source_ids) + len(target_ids) - encoded_pairs.append((source_ids, target_ids)) - - return encoded_pairs - @dataclass class Llama2Template(Template): def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: str, tools: str, - cutoff_len: int, - reserved_label_len: int, - ) -> Sequence[Tuple[List[int], List[int]]]: + ) -> List[List[int]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: system + query resp - Turn t: sep + query resp + Turn 0: prefix + system + query resp + Turn t: sep + query resp """ system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): elements = [] + system_text = "" - if i == 0 and (system or tools or self.force_system): - tool_text = self.format_tools.apply(content=tools)[0] if tools else "" - system_text = self.format_system.apply(content=(system + tool_text))[0] - elif i > 0 and i % 2 == 0: + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] + + if i > 0 and i % 2 == 0: elements += self.format_separator.apply() if message["role"] == Role.USER.value: @@ -193,7 +187,7 @@ class Llama2Template(Template): encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) - return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + return encoded_messages TEMPLATES: Dict[str, Template] = {} @@ -208,12 +202,12 @@ def _register_template( format_observation: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None, format_separator: Optional["Formatter"] = None, + format_prefix: Optional["Formatter"] = None, default_system: str = "", - stop_words: List[str] = [], + stop_words: Sequence[str] = [], image_token: str = "", efficient_eos: bool = False, replace_eos: bool = False, - force_system: bool = False, ) -> None: r""" Registers a chat template. @@ -245,9 +239,10 @@ def _register_template( template_class = Llama2Template if name.startswith("llama2") else Template default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) - default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) + default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default") default_separator_formatter = EmptyFormatter() + default_prefix_formatter = EmptyFormatter() TEMPLATES[name] = template_class( format_user=format_user or default_user_formatter, format_assistant=format_assistant or default_assistant_formatter, @@ -256,12 +251,12 @@ def _register_template( format_observation=format_observation or format_user or default_user_formatter, format_tools=format_tools or default_tool_formatter, format_separator=format_separator or default_separator_formatter, + format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, stop_words=stop_words, image_token=image_token, efficient_eos=efficient_eos, replace_eos=replace_eos, - force_system=force_system, ) @@ -307,6 +302,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: jinja_template = "" + prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer) + if prefix: + jinja_template += "{{ " + prefix + " }}" + if template.default_system: jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" @@ -315,11 +314,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") ) system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") - if isinstance(template, Llama2Template): - pass - elif template.force_system: - jinja_template += "{{ " + system_message + " }}" - else: + if not isinstance(template, Llama2Template): jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" jinja_template += "{% for message in messages %}" @@ -346,6 +341,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") def get_template_and_fix_tokenizer( tokenizer: "PreTrainedTokenizer", name: Optional[str] = None, + tool_format: Optional[str] = None, ) -> Template: if name is None: template = TEMPLATES["empty"] # placeholder @@ -354,6 +350,12 @@ def get_template_and_fix_tokenizer( if template is None: raise ValueError("Template {} does not exist.".format(name)) + if tool_format is not None: + logger.info("Using tool format: {}.".format(tool_format)) + eos_slots = [] if template.efficient_eos else [{"eos_token"}] + template.format_tools = ToolFormatter(tool_format=tool_format) + template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format) + stop_words = template.stop_words if template.replace_eos: if not stop_words: @@ -435,9 +437,8 @@ _register_template( _register_template( name="belle", format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_separator=EmptyFormatter(slots=["\n\n"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -450,11 +451,7 @@ _register_template( _register_template( name="breeze", format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - default_system=( - "You are a helpful AI assistant built by MediaTek Research. " - "The user you are helping speaks Traditional Chinese and comes from Taiwan." - ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), efficient_eos=True, ) @@ -462,10 +459,9 @@ _register_template( _register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_separator=EmptyFormatter(slots=["\n\n"]), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), efficient_eos=True, - force_system=True, ) @@ -473,32 +469,13 @@ _register_template( name="chatglm3", format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), - format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), format_observation=StringFormatter( slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] ), - stop_words=["<|user|>", "<|observation|>"], - efficient_eos=True, - force_system=True, -) - - -_register_template( - name="chatglm3_system", - format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), - format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_system=StringFormatter( - slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] - ), - format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), - format_observation=StringFormatter( - slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] - ), - default_system=( - "You are ChatGLM3, a large language model trained by Zhipu.AI. " - "Follow the user's instructions carefully. Respond using markdown." - ), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, ) @@ -529,8 +506,7 @@ _register_template( _register_template( name="codegeex2", - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), ) @@ -544,21 +520,15 @@ _register_template( ) ] ), - format_system=StringFormatter( - slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"] - ), - default_system=( - "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users " - "by providing thorough responses. You are trained by Cohere." - ), + format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) _register_template( name="cpm", format_user=StringFormatter(slots=["<用户>{{content}}"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -591,30 +561,28 @@ _register_template( _register_template( name="deepseek", format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) _register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), - format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), default_system=( "You are an AI programming assistant, utilizing the Deepseek Coder model, " "developed by Deepseek Company, and you only answer questions related to computer science. " "For politically sensitive questions, security and privacy issues, " "and other non-computer science questions, you will refuse to answer\n" ), - stop_words=["<|EOT|>"], - efficient_eos=True, ) _register_template( name="default", - format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]), + format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), format_system=StringFormatter(slots=["{{content}}\n"]), format_separator=EmptyFormatter(slots=["\n"]), ) @@ -622,11 +590,7 @@ _register_template( _register_template( name="empty", - format_user=StringFormatter(slots=["{{content}}"]), - format_assistant=StringFormatter(slots=["{{content}}"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), efficient_eos=True, - force_system=True, ) @@ -648,13 +612,12 @@ _register_template( _register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_observation=StringFormatter( slots=["tool\n{{content}}\nmodel\n"] ), format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), efficient_eos=True, - force_system=True, ) @@ -662,36 +625,33 @@ _register_template( name="glm4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_assistant=StringFormatter(slots=["\n{{content}}"]), - format_system=StringFormatter(slots=["[gMASK]{{content}}"]), - format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, - force_system=True, ) _register_template( name="intern", - format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]), - format_separator=EmptyFormatter(slots=[{"token": ""}, "\n"]), + format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), + format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=[""], - efficient_eos=True, + efficient_eos=True, # internlm tokenizer cannot set eos_token_id ) _register_template( name="intern2", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), - format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_separator=EmptyFormatter(slots=["\n"]), - default_system=( - "You are an AI assistant whose name is InternLM (书生·浦语).\n" - "- InternLM (书生·浦语) is a conversational language model that is developed " - "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" - "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen " - "by the user such as English and 中文." - ), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|im_end|>"], efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id ) @@ -700,7 +660,6 @@ _register_template( _register_template( name="llama2", format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), - format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]), format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), ) @@ -723,9 +682,7 @@ _register_template( ) ] ), - format_system=StringFormatter( - slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"] - ), + format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), format_observation=StringFormatter( slots=[ ( @@ -734,7 +691,7 @@ _register_template( ) ] ), - default_system="You are a helpful assistant.", + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|eot_id|>"], replace_eos=True, ) @@ -743,24 +700,21 @@ _register_template( _register_template( name="mistral", format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) _register_template( name="olmo", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), - format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"eos_token"}]), ) _register_template( name="openchat", format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -774,27 +728,25 @@ _register_template( ) ] ), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|eot_id|>"], replace_eos=True, - force_system=True, ) _register_template( name="orion", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) _register_template( name="phi", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), - format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), format_separator=EmptyFormatter(slots=["\n"]), - default_system="You are a helpful AI assistant.", + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|end|>"], replace_eos=True, ) @@ -827,7 +779,6 @@ _register_template( format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|end|>"], replace_eos=True, - force_system=True, ) diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py new file mode 100644 index 00000000..ac5565d5 --- /dev/null +++ b/src/llamafactory/data/tool_utils.py @@ -0,0 +1,140 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +from .data_utils import SLOTS + + +DEFAULT_TOOL_PROMPT = ( + "You have access to the following tools:\n{tool_text}" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [{tool_names}]).\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n""" + "```\n" +) + + +GLM4_TOOL_PROMPT = ( + "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}" +) + + +@dataclass +class ToolUtils(ABC): + @staticmethod + @abstractmethod + def get_function_slots() -> SLOTS: ... + + @staticmethod + @abstractmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: ... + + @staticmethod + @abstractmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ... + + +class DefaultToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["Action: {{name}}\nAction Input: {{arguments}}\n"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + tool_names = [] + for tool in tools: + param_text = "" + for name, param in tool["parameters"]["properties"].items(): + required, enum, items = "", "", "" + if name in tool["parameters"].get("required", []): + required = ", required" + + if param.get("enum", None): + enum = ", should be one of [{}]".format(", ".join(param["enum"])) + + if param.get("items", None): + items = ", where each item should be {}".format(param["items"].get("type", "")) + + param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format( + name=name, + type=param.get("type", ""), + required=required, + desc=param.get("description", ""), + enum=enum, + items=items, + ) + + tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( + name=tool["name"], desc=tool.get("description", ""), args=param_text + ) + tool_names.append(tool["name"]) + + return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names)) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) + action_match: List[Tuple[str, str]] = re.findall(regex, content) + if not action_match: + return content + + results = [] + for match in action_match: + tool_name = match[0].strip() + tool_input = match[1].strip().strip('"').strip("```") + try: + arguments = json.loads(tool_input) + results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) + except json.JSONDecodeError: + return content + + return results + + +class GLM4ToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["{{name}}\n{{arguments}}"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( + name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) + ) + + return GLM4_TOOL_PROMPT.format(tool_text=tool_text) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + if "\n" not in content: + return content + + tool_name, tool_input = content.split("\n", maxsplit=1) + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py index 192f4815..d3140793 100644 --- a/src/llamafactory/eval/evaluator.py +++ b/src/llamafactory/eval/evaluator.py @@ -1,4 +1,41 @@ -# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py +# Copyright 2024 the LlamaFactory team. +# +# This code is inspired by the Dan's test library. +# https://github.com/hendrycks/test/blob/master/evaluate_flan.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2020 Dan Hendrycks +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. import inspect import json @@ -26,9 +63,7 @@ class Evaluator: self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) self.model = load_model(self.tokenizer, self.model_args, finetuning_args) self.eval_template = get_eval_template(self.eval_args.lang) - self.choice_inputs = [ - self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES - ] + self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] @torch.inference_mode() def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: diff --git a/src/llamafactory/eval/template.py b/src/llamafactory/eval/template.py index a4a6ef0e..7d524e7c 100644 --- a/src/llamafactory/eval/template.py +++ b/src/llamafactory/eval/template.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Dict, List, Sequence, Tuple @@ -10,7 +24,6 @@ class EvalTemplate: system: str choice: str answer: str - prefix: str def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: r""" @@ -42,8 +55,8 @@ class EvalTemplate: eval_templates: Dict[str, "EvalTemplate"] = {} -def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None: - eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix) +def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: + eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer) def get_eval_template(name: str) -> "EvalTemplate": @@ -56,8 +69,7 @@ _register_eval_template( name="en", system="The following are multiple choice questions (with answers) about {subject}.\n\n", choice="\n{choice}. {content}", - answer="\nAnswer: ", - prefix=" ", + answer="\nAnswer:", ) @@ -66,5 +78,4 @@ _register_eval_template( system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", choice="\n{choice}. {content}", answer="\n答案:", - prefix=" ", ) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 466b1269..6029d84f 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import OrderedDict, defaultdict from enum import Enum from typing import Dict, Optional @@ -404,6 +418,18 @@ register_model_group( DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat", }, + "DeepSeek-MoE-Coder-16B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base", + }, + "DeepSeek-MoE-Coder-236B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base", + }, + "DeepSeek-MoE-Coder-16B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", + }, + "DeepSeek-MoE-Coder-236B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct", + }, }, template="deepseek", ) @@ -496,6 +522,18 @@ register_model_group( "Gemma-1.1-7B-Chat": { DownloadSource.DEFAULT: "google/gemma-1.1-7b-it", }, + "Gemma-2-9B": { + DownloadSource.DEFAULT: "google/gemma-2-9b", + }, + "Gemma-2-27B": { + DownloadSource.DEFAULT: "google/gemma-2-27b", + }, + "Gemma-2-9B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-9b-it", + }, + "Gemma-2-27B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-27b-it", + }, }, template="gemma", ) @@ -568,7 +606,7 @@ register_model_group( register_model_group( models={ - "Jambda-v0.1": { + "Jamba-v0.1": { DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1", DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1", } @@ -683,6 +721,21 @@ register_model_group( ) +register_model_group( + models={ + "MiniCPM-2B-SFT-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16", + }, + "MiniCPM-2B-DPO-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16", + }, + }, + template="cpm", +) + + register_model_group( models={ "Mistral-7B-v0.1": { diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py index 1d4e43f1..14876048 100644 --- a/src/llamafactory/extras/env.py +++ b/src/llamafactory/extras/env.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import platform import accelerate @@ -9,7 +26,7 @@ import trl from transformers.utils import is_torch_cuda_available, is_torch_npu_available -VERSION = "0.8.1.dev0" +VERSION = "0.8.3.dev0" def print_env() -> None: diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 430b8a48..67622212 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os import sys diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index fc33f77e..20c752c5 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -1,13 +1,29 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import gc import os -from typing import TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Tuple import torch -from peft import PeftModel -from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel +import transformers.dynamic_module_utils +from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList +from transformers.dynamic_module_utils import get_relative_imports from transformers.utils import ( - SAFE_WEIGHTS_NAME, - WEIGHTS_NAME, is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_mps_available, @@ -16,7 +32,6 @@ from transformers.utils import ( ) from transformers.utils.versions import require_version -from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from .logging import get_logger @@ -28,8 +43,6 @@ except Exception: if TYPE_CHECKING: - from trl import AutoModelForCausalLMWithValueHead - from ..hparams import ModelArguments @@ -58,6 +71,9 @@ class AverageMeter: def check_dependencies() -> None: + r""" + Checks the version of the required packages. + """ if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") else: @@ -68,7 +84,7 @@ def check_dependencies() -> None: require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6") -def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: +def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: r""" Returns the number of trainable parameters and number of all parameters in the model. """ @@ -79,7 +95,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: if num_params == 0 and hasattr(param, "ds_numel"): num_params = param.ds_numel - # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 + # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize if param.__class__.__name__ == "Params4bit": if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"): num_bytes = param.quant_storage.itemsize @@ -97,55 +113,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: return trainable_params, all_param -def fix_valuehead_checkpoint( - model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool -) -> None: - r""" - The model is already unwrapped. - - There are three cases: - 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...} - 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...} - 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...} - - We assume `stage3_gather_16bit_weights_on_model_save=true`. - """ - if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)): - return - - if safe_serialization: - from safetensors import safe_open - from safetensors.torch import save_file - - path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) - with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: - state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} - else: - path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) - state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") - - decoder_state_dict = {} - v_head_state_dict = {} - for name, param in state_dict.items(): - if name.startswith("v_head."): - v_head_state_dict[name] = param - else: - decoder_state_dict[name.replace("pretrained_model.", "")] = param - - os.remove(path_to_checkpoint) - model.pretrained_model.save_pretrained( - output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization - ) - - if safe_serialization: - save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) - else: - torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) - - logger.info("Value head model saved at: {}".format(output_dir)) - - -def get_current_device() -> torch.device: +def get_current_device() -> "torch.device": r""" Gets the current available device. """ @@ -184,7 +152,14 @@ def get_logits_processor() -> "LogitsProcessorList": return logits_processor -def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: +def has_tokenized_data(path: "os.PathLike") -> bool: + r""" + Checks if the path has a tokenized dataset. + """ + return os.path.isdir(path) and len(os.listdir(path)) > 0 + + +def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": r""" Infers the optimal dtype according to the model_dtype and device compatibility. """ @@ -203,11 +178,9 @@ def is_gpu_or_npu_available() -> bool: return is_torch_npu_available() or is_torch_cuda_available() -def has_tokenized_data(path: os.PathLike) -> bool: - r""" - Checks if the path has a tokenized dataset. - """ - return os.path.isdir(path) and len(os.listdir(path)) > 0 +def skip_check_imports() -> None: + if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: + transformers.dynamic_module_utils.check_imports = get_relative_imports def torch_gc() -> None: diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 4c9e6492..0a84a293 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -1,5 +1,23 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import importlib.metadata import importlib.util +from functools import lru_cache from typing import TYPE_CHECKING from packaging import version @@ -24,10 +42,6 @@ def is_fastapi_available(): return _is_package_available("fastapi") -def is_flash_attn2_available(): - return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0") - - def is_galore_available(): return _is_package_available("galore_torch") @@ -36,18 +50,10 @@ def is_gradio_available(): return _is_package_available("gradio") -def is_jieba_available(): - return _is_package_available("jieba") - - def is_matplotlib_available(): return _is_package_available("matplotlib") -def is_nltk_available(): - return _is_package_available("nltk") - - def is_pillow_available(): return _is_package_available("PIL") @@ -60,10 +66,6 @@ def is_rouge_available(): return _is_package_available("rouge_chinese") -def is_sdpa_available(): - return _get_package_version("torch") > version.parse("2.1.1") - - def is_starlette_available(): return _is_package_available("sse_starlette") @@ -74,3 +76,8 @@ def is_uvicorn_available(): def is_vllm_available(): return _is_package_available("vllm") + + +@lru_cache +def is_vllm_version_greater_than_0_5(): + return _get_package_version("vllm") >= version.parse("0.5.0") diff --git a/src/llamafactory/extras/ploting.py b/src/llamafactory/extras/ploting.py index dea23bbe..596d55e7 100644 --- a/src/llamafactory/extras/ploting.py +++ b/src/llamafactory/extras/ploting.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import math import os diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py index d1ee98dd..cfe448c1 100644 --- a/src/llamafactory/hparams/__init__.py +++ b/src/llamafactory/hparams/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index d2d53ec8..e351fccf 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass, field from typing import Literal, Optional @@ -28,10 +45,6 @@ class DataArguments: default=1024, metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, ) - reserved_label_len: int = field( - default=1, - metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."}, - ) train_on_prompt: bool = field( default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}, @@ -90,15 +103,16 @@ class DataArguments: "help": "Whether or not to pack the sequences without cross-contamination attention for efficient training." }, ) + tool_format: Optional[str] = field( + default=None, + metadata={"help": "Tool format to use for constructing function calling examples."}, + ) tokenized_path: Optional[str] = field( default=None, metadata={"help": "Path to save or load the tokenized datasets."}, ) def __post_init__(self): - if self.reserved_label_len >= self.cutoff_len: - raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.") - if self.streaming and self.val_size > 1e-6 and self.val_size < 1: raise ValueError("Streaming mode should have an integer val size.") diff --git a/src/llamafactory/hparams/evaluation_args.py b/src/llamafactory/hparams/evaluation_args.py index 5a05f6f6..a7f221ca 100644 --- a/src/llamafactory/hparams/evaluation_args.py +++ b/src/llamafactory/hparams/evaluation_args.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from dataclasses import dataclass, field from typing import Literal, Optional diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 08af31e4..3867c0ec 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -1,5 +1,19 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import List, Literal, Optional @dataclass @@ -94,6 +108,18 @@ class LoraArguments: default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}, ) + pissa_init: bool = field( + default=False, + metadata={"help": "Whether or not to initialize a PiSSA adapter."}, + ) + pissa_iter: int = field( + default=16, + metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."}, + ) + pissa_convert: bool = field( + default=False, + metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."}, + ) create_new_adapter: bool = field( default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, @@ -319,20 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA return [item.strip() for item in arg.split(",")] return arg - self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules) - self.freeze_extra_modules = split_arg(self.freeze_extra_modules) - self.lora_alpha = self.lora_alpha or self.lora_rank * 2 - self.lora_target = split_arg(self.lora_target) - self.additional_target = split_arg(self.additional_target) - self.galore_target = split_arg(self.galore_target) + self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) + self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) + self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 + self.lora_target: List[str] = split_arg(self.lora_target) + self.additional_target: Optional[List[str]] = split_arg(self.additional_target) + self.galore_target: List[str] = split_arg(self.galore_target) self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only + self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"] assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." - self.use_ref_model = self.pref_loss not in ["orpo", "simpo"] - if self.stage == "ppo" and self.reward_model is None: raise ValueError("`reward_model` is necessary for PPO training.") @@ -354,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + if self.pissa_init and self.finetuning_type != "lora": + raise ValueError("`pissa_init` is only valid for LoRA training.") + + if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model): + raise ValueError("Cannot use PiSSA for current training stage.") + if self.train_mm_proj_only and self.finetuning_type != "full": raise ValueError("`train_mm_proj_only` is only valid for full training.") diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py index 0ee17d1a..7ebb4eed 100644 --- a/src/llamafactory/hparams/generating_args.py +++ b/src/llamafactory/hparams/generating_args.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import asdict, dataclass, field from typing import Any, Dict, Optional diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 6352a420..087c8c38 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -1,5 +1,28 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import asdict, dataclass, field -from typing import Any, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union + +from typing_extensions import Self + + +if TYPE_CHECKING: + import torch @dataclass @@ -22,6 +45,10 @@ class ModelArguments: ) }, ) + adapter_folder: Optional[str] = field( + default=None, + metadata={"help": "The folder containing the adapter weights to load."}, + ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, @@ -50,6 +77,10 @@ class ModelArguments: default=True, metadata={"help": "Whether or not to use memory-efficient model loading."}, ) + quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field( + default="bitsandbytes", + metadata={"help": "Quantization method to use for on-the-fly quantization."}, + ) quantization_bit: Optional[int] = field( default=None, metadata={"help": "The number of bits to quantize the model using bitsandbytes."}, @@ -70,7 +101,7 @@ class ModelArguments: default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, ) - flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field( + flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field( default="auto", metadata={"help": "Enable FlashAttention for faster training and inference."}, ) @@ -127,13 +158,9 @@ class ModelArguments: metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, ) vllm_max_lora_rank: int = field( - default=8, + default=32, metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, ) - vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field( - default="auto", - metadata={"help": "Data type for model weights and activations in the vLLM engine."}, - ) offload_folder: str = field( default="offload", metadata={"help": "Path to offload model weights."}, @@ -142,6 +169,10 @@ class ModelArguments: default=True, metadata={"help": "Whether or not to use KV cache in generation."}, ) + infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field( + default="auto", + metadata={"help": "Data type for model weights and activations at inference."}, + ) hf_hub_token: Optional[str] = field( default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}, @@ -192,9 +223,9 @@ class ModelArguments: ) def __post_init__(self): - self.compute_dtype = None - self.device_map = None - self.model_max_length = None + self.compute_dtype: Optional["torch.dtype"] = None + self.device_map: Optional[Union[str, Dict[str, Any]]] = None + self.model_max_length: Optional[int] = None if self.split_special_tokens and self.use_fast_tokenizer: raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") @@ -208,11 +239,18 @@ class ModelArguments: if self.new_special_tokens is not None: # support multiple special tokens self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] - assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." - assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." - if self.export_quantization_bit is not None and self.export_quantization_dataset is None: raise ValueError("Quantization dataset is necessary for exporting.") def to_dict(self) -> Dict[str, Any]: return asdict(self) + + @classmethod + def copyfrom(cls, old_arg: Self, **kwargs) -> Self: + arg_dict = old_arg.to_dict() + arg_dict.update(**kwargs) + new_arg = cls(**arg_dict) + new_arg.compute_dtype = old_arg.compute_dtype + new_arg.device_map = old_arg.device_map + new_arg.model_max_length = old_arg.model_max_length + return new_arg diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index ff1fbf5d..8b2ea4c1 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os import sys @@ -8,6 +25,7 @@ import transformers from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint +from transformers.training_args import ParallelMode from transformers.utils import is_torch_bf16_gpu_available from transformers.utils.versions import require_version @@ -65,13 +83,13 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": raise ValueError("Adapter is only valid for the LoRA method.") - if model_args.use_unsloth and is_deepspeed_zero3_enabled(): - raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") - if model_args.quantization_bit is not None: if finetuning_args.finetuning_type != "lora": raise ValueError("Quantization is only compatible with the LoRA method.") + if finetuning_args.pissa_init: + raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.") + if model_args.resize_vocab: raise ValueError("Cannot resize embedding layers of a quantized model.") @@ -100,7 +118,7 @@ def _check_extra_dependencies( require_version("galore_torch", "To fix: pip install galore_torch") if finetuning_args.use_badam: - require_version("badam", "To fix: pip install badam") + require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1") if finetuning_args.plot_loss: require_version("matplotlib", "To fix: pip install matplotlib") @@ -162,6 +180,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ): raise ValueError("PPO only accepts wandb or tensorboard logger.") + if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED: + raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.") + + if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED: + raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.") + if training_args.max_steps == -1 and data_args.streaming: raise ValueError("Please specify `max_steps` in streaming mode.") @@ -171,32 +195,31 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if training_args.do_train and model_args.quantization_device_map == "auto": raise ValueError("Cannot use device map for quantized models in training.") - if finetuning_args.use_dora and model_args.use_unsloth: - raise ValueError("Unsloth does not support DoRA.") + if finetuning_args.pissa_init and is_deepspeed_zero3_enabled(): + raise ValueError("PiSSA is incompatible with DeepSpeed ZeRO-3.") if finetuning_args.pure_bf16: if not is_torch_bf16_gpu_available(): raise ValueError("This device does not support `pure_bf16`.") - if training_args.fp16 or training_args.bf16: - raise ValueError("Turn off mixed precision training when using `pure_bf16`.") + if is_deepspeed_zero3_enabled(): + raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.") if ( finetuning_args.use_galore and finetuning_args.galore_layerwise - and training_args.parallel_mode.value == "distributed" + and training_args.parallel_mode == ParallelMode.DISTRIBUTED ): raise ValueError("Distributed training does not support layer-wise GaLore.") - if ( - finetuning_args.use_badam - and finetuning_args.badam_mode == "layer" - and training_args.parallel_mode.value == "distributed" - ): - raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.") + if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED: + if finetuning_args.badam_mode == "ratio": + raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.") + elif not is_deepspeed_zero3_enabled(): + raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.") - if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None: - raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.") + if finetuning_args.use_galore and training_args.deepspeed is not None: + raise ValueError("GaLore is incompatible with DeepSpeed yet.") if model_args.infer_backend == "vllm": raise ValueError("vLLM backend is only available for API, CLI and Web.") @@ -204,6 +227,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if model_args.visual_inputs and data_args.packing: raise ValueError("Cannot use packing in MLLM fine-tuning.") + if model_args.use_unsloth and is_deepspeed_zero3_enabled(): + raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") + _verify_model_args(model_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args, training_args) @@ -233,7 +259,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: # Post-process training arguments if ( - training_args.parallel_mode.value == "distributed" + training_args.parallel_mode == ParallelMode.DISTRIBUTED and training_args.ddp_find_unused_parameters is None and finetuning_args.finetuning_type == "lora" ): @@ -293,7 +319,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: training_args.local_rank, training_args.device, training_args.n_gpu, - training_args.parallel_mode.value == "distributed", + training_args.parallel_mode == ParallelMode.DISTRIBUTED, str(model_args.compute_dtype), ) ) @@ -332,6 +358,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: if model_args.export_dir is not None and model_args.export_device == "cpu": model_args.device_map = {"": torch.device("cpu")} + model_args.model_max_length = data_args.cutoff_len else: model_args.device_map = "auto" diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py index de154db9..65e0b68f 100644 --- a/src/llamafactory/launcher.py +++ b/src/llamafactory/launcher.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from llamafactory.train.tuner import run_exp diff --git a/src/llamafactory/model/__init__.py b/src/llamafactory/model/__init__.py index 9d23d59f..48cfe76c 100644 --- a/src/llamafactory/model/__init__.py +++ b/src/llamafactory/model/__init__.py @@ -1,9 +1,25 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .loader import load_config, load_model, load_tokenizer from .model_utils.misc import find_all_linear_modules +from .model_utils.quantization import QuantizationMethod from .model_utils.valuehead import load_valuehead_params __all__ = [ + "QuantizationMethod", "load_config", "load_model", "load_tokenizer", diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index f4e501a7..7caef9cc 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re from typing import TYPE_CHECKING @@ -25,8 +39,12 @@ def _setup_full_tuning( model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", + is_trainable: bool, cast_trainable_params_to_fp32: bool, ) -> None: + if not is_trainable: + return + logger.info("Fine-tuning method: Full") forbidden_modules = set() if model_args.visual_inputs and finetuning_args.freeze_vision_tower: @@ -47,8 +65,12 @@ def _setup_freeze_tuning( model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", + is_trainable: bool, cast_trainable_params_to_fp32: bool, ) -> None: + if not is_trainable: + return + logger.info("Fine-tuning method: Freeze") if model_args.visual_inputs: config = model.config.text_config @@ -132,7 +154,9 @@ def _setup_lora_tuning( is_trainable: bool, cast_trainable_params_to_fp32: bool, ) -> "PeftModel": - logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + if is_trainable: + logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + adapter_to_resume = None if model_args.adapter_name_or_path is not None: @@ -155,8 +179,16 @@ def _setup_lora_tuning( else: adapter_to_merge = model_args.adapter_name_or_path + init_kwargs = { + "subfolder": model_args.adapter_folder, + "offload_folder": model_args.offload_folder, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + for adapter in adapter_to_merge: - model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder) + model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) model = model.merge_and_unload() if len(adapter_to_merge) > 0: @@ -166,12 +198,9 @@ def _setup_lora_tuning( if model_args.use_unsloth: model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) else: - model = PeftModel.from_pretrained( - model, - adapter_to_resume, - is_trainable=is_trainable, - offload_folder=model_args.offload_folder, - ) + model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) + + logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) if is_trainable and adapter_to_resume is None: # create new lora weights while training if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": @@ -209,16 +238,24 @@ def _setup_lora_tuning( "lora_alpha": finetuning_args.lora_alpha, "lora_dropout": finetuning_args.lora_dropout, "use_rslora": finetuning_args.use_rslora, + "use_dora": finetuning_args.use_dora, "modules_to_save": finetuning_args.additional_target, } if model_args.use_unsloth: model = get_unsloth_peft_model(model, model_args, peft_kwargs) else: + if finetuning_args.pissa_init: + if finetuning_args.pissa_iter == -1: + logger.info("Using PiSSA initialization.") + peft_kwargs["init_lora_weights"] = "pissa" + else: + logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) + peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) + lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, - use_dora=finetuning_args.use_dora, **peft_kwargs, ) model = get_peft_model(model, lora_config) @@ -227,9 +264,6 @@ def _setup_lora_tuning( for param in filter(lambda p: p.requires_grad, model.parameters()): param.data = param.data.to(torch.float32) - if model_args.adapter_name_or_path is not None: - logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) - return model @@ -247,29 +281,36 @@ def init_adapter( Note that the trainable parameters must be cast to float32. """ - if (not is_trainable) and model_args.adapter_name_or_path is None: - logger.info("Adapter is not found at evaluation, load the base model.") - return model + if is_trainable and getattr(model, "quantization_method", None) is not None: + if finetuning_args.finetuning_type != "lora": + raise ValueError("Quantized models can only be used for the LoRA tuning.") - if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None): - raise ValueError("You can only use lora for quantized models.") + if finetuning_args.pissa_init: + raise ValueError("Cannot initialize PiSSA adapter on quantized models.") - if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam: - logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.") - cast_trainable_params_to_fp32 = False + # cast trainable parameters to float32 if: + # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora) + # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32) + cast_trainable_params_to_fp32 = False + if not is_trainable: + pass + elif finetuning_args.pure_bf16 or finetuning_args.use_badam: + logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.") + elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): + logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.") else: logger.info("Upcasting trainable params to float32.") cast_trainable_params_to_fp32 = True - if is_trainable and finetuning_args.finetuning_type == "full": - _setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32) - - if is_trainable and finetuning_args.finetuning_type == "freeze": - _setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32) - - if finetuning_args.finetuning_type == "lora": + if finetuning_args.finetuning_type == "full": + _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "freeze": + _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "lora": model = _setup_lora_tuning( config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 ) + else: + raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type)) return model diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 026a09be..43e65d52 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -1,10 +1,25 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict +import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer from trl import AutoModelForCausalLMWithValueHead from ..extras.logging import get_logger -from ..extras.misc import count_parameters, try_download_model_from_ms +from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms from .adapter import init_adapter from .model_utils.misc import register_autoclass from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model @@ -33,6 +48,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: Note: including inplace operation of model_args. """ + skip_check_imports() model_args.model_name_or_path = try_download_model_from_ms(model_args) return { "trust_remote_code": True, @@ -162,17 +178,21 @@ def load_model( if not is_trainable: model.requires_grad_(False) + for param in model.parameters(): + if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32: + param.data = param.data.to(model_args.compute_dtype) + model.eval() else: model.train() trainable_params, all_param = count_parameters(model) if is_trainable: - param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( trainable_params, all_param, 100 * trainable_params / all_param ) else: - param_stats = "all params: {:d}".format(all_param) + param_stats = "all params: {:,}".format(all_param) logger.info(param_stats) diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py index b52ddc86..4bed7e21 100644 --- a/src/llamafactory/model/model_utils/attention.py +++ b/src/llamafactory/model/model_utils/attention.py @@ -1,7 +1,22 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING +from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available + from ...extras.logging import get_logger -from ...extras.packages import is_flash_attn2_available, is_sdpa_available if TYPE_CHECKING: @@ -13,21 +28,33 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None: +def configure_attn_implementation( + config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool +) -> None: + if getattr(config, "model_type", None) == "gemma2" and is_trainable: # gemma2 adopts soft-cap attention + if model_args.flash_attn == "auto": + logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.") + model_args.flash_attn = "disabled" + elif model_args.flash_attn != "disabled": + logger.warning( + "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. " + "Will proceed at your own risk.".format(model_args.flash_attn) + ) + if model_args.flash_attn == "auto": return - elif model_args.flash_attn == "off": + elif model_args.flash_attn == "disabled": requested_attn_implementation = "eager" elif model_args.flash_attn == "sdpa": - if not is_sdpa_available(): + if not is_torch_sdpa_available(): logger.warning("torch>=2.1.1 is required for SDPA attention.") return requested_attn_implementation = "sdpa" elif model_args.flash_attn == "fa2": - if not is_flash_attn2_available(): + if not is_flash_attn_2_available(): logger.warning("FlashAttention-2 is not installed.") return diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index e0657be8..f4f3d8a5 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -1,3 +1,21 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers and PEFT library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py +# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect from functools import partial from types import MethodType @@ -60,15 +78,12 @@ def _fp32_forward_post_hook( return output.to(torch.float32) -def prepare_model_for_training( - model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head" -) -> None: +def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None: r""" Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) add the upcasting of the lm_head in fp32 - Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 """ if model_args.upcast_layernorm: logger.info("Upcasting layernorm weights in float32.") @@ -87,8 +102,8 @@ def prepare_model_for_training( setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled logger.info("Gradient checkpointing enabled.") - if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output: - logger.info("Upcasting lm_head outputs in float32.") - output_layer = getattr(model, output_layer_name) + if model_args.upcast_lmhead_output: + output_layer = model.get_output_embeddings() if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: + logger.info("Upcasting lm_head outputs in float32.") output_layer.register_forward_hook(_fp32_forward_post_hook) diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py index 3d9278e3..3ff79828 100644 --- a/src/llamafactory/model/model_utils/embedding.py +++ b/src/llamafactory/model/model_utils/embedding.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from contextlib import nullcontext from typing import TYPE_CHECKING diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py index c8dc52f5..af30bd50 100644 --- a/src/llamafactory/model/model_utils/longlora.py +++ b/src/llamafactory/model/model_utils/longlora.py @@ -1,3 +1,22 @@ +# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team. +# +# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +# This code is also inspired by the original LongLoRA implementation. +# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from typing import TYPE_CHECKING, Optional, Tuple @@ -96,7 +115,8 @@ def llama_attention_forward( ( attn_output[:, :, : self.num_heads // 2], attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), - ) + ), + dim=2, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -181,11 +201,9 @@ def llama_flash_attention_2_forward( query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) if attention_mask is not None: attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1) - else: - groupsz = q_len attn_output: torch.Tensor = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate + query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate ) if getattr(self.config, "group_size_ratio", None) and self.training: # shift back @@ -194,7 +212,8 @@ def llama_flash_attention_2_forward( ( attn_output[:, :, : self.num_heads // 2], attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), - ) + ), + dim=2, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() @@ -293,7 +312,8 @@ def llama_sdpa_attention_forward( ( attn_output[:, :, : self.num_heads // 2], attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), - ) + ), + dim=2, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -303,7 +323,7 @@ def llama_sdpa_attention_forward( def _apply_llama_patch() -> None: - require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2") + require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2") LlamaAttention.forward = llama_attention_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py index 4851bd29..a2812228 100644 --- a/src/llamafactory/model/model_utils/misc.py +++ b/src/llamafactory/model/model_utils/misc.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, List from ...extras.logging import get_logger diff --git a/src/llamafactory/model/model_utils/mod.py b/src/llamafactory/model/model_utils/mod.py index 5708a1a8..ec73af00 100644 --- a/src/llamafactory/model/model_utils/mod.py +++ b/src/llamafactory/model/model_utils/mod.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING from ...extras.constants import MOD_SUPPORTED_MODELS diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py index e554e45a..5c7473aa 100644 --- a/src/llamafactory/model/model_utils/moe.py +++ b/src/llamafactory/model/model_utils/moe.py @@ -1,5 +1,20 @@ -from typing import TYPE_CHECKING +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, Sequence + +import torch from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils.versions import require_version @@ -10,6 +25,13 @@ if TYPE_CHECKING: from ...hparams import ModelArguments +def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + + set_z3_leaf_modules(model, leaf_modules) + + def add_z3_leaf_module(model: "PreTrainedModel") -> None: r""" Sets module as a leaf module to skip partitioning in deepspeed zero3. @@ -17,33 +39,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: if not is_deepspeed_zero3_enabled(): return - require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") - from deepspeed.utils import set_z3_leaf_modules # type: ignore - if getattr(model.config, "model_type", None) == "dbrx": from transformers.models.dbrx.modeling_dbrx import DbrxFFN - set_z3_leaf_modules(model, [DbrxFFN]) + _set_z3_leaf_modules(model, [DbrxFFN]) if getattr(model.config, "model_type", None) == "jamba": from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock - set_z3_leaf_modules(model, [JambaSparseMoeBlock]) + _set_z3_leaf_modules(model, [JambaSparseMoeBlock]) if getattr(model.config, "model_type", None) == "jetmoe": from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE - set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) + _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) if getattr(model.config, "model_type", None) == "mixtral": from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + _set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) if getattr(model.config, "model_type", None) == "qwen2moe": from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock - set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index 02a54f07..317646e0 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -1,3 +1,21 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers and Optimum library. +# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py +# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import random from enum import Enum, unique @@ -5,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List import torch from datasets import load_dataset -from transformers import BitsAndBytesConfig, GPTQConfig +from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled from transformers.utils.versions import require_version @@ -39,10 +57,9 @@ class QuantizationMethod(str, Enum): HQQ = "hqq" -def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]: r""" - Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 - TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 + Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization. """ if os.path.isfile(model_args.export_quantization_dataset): data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) @@ -51,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod data_path = model_args.export_quantization_dataset data_files = None - dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) - maxlen = model_args.export_quantization_maxlen + dataset = load_dataset( + path=data_path, + data_files=data_files, + split="train", + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + ) samples = [] + maxlen = model_args.export_quantization_maxlen for _ in range(model_args.export_quantization_nsamples): + n_try = 0 while True: + if n_try > 100: + raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.") + sample_idx = random.randint(0, len(dataset) - 1) - sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") - if sample["input_ids"].size(1) >= maxlen: + sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + n_try += 1 + if sample["input_ids"].size(1) > maxlen: break # TODO: fix large maxlen word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] - samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) + attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen] + samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()}) return samples @@ -76,14 +105,14 @@ def configure_quantization( init_kwargs: Dict[str, Any], ) -> None: r""" - Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) + Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer) """ if getattr(config, "quantization_config", None): # ptq - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.") + if model_args.quantization_bit is not None: + logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.") - if model_args.quantization_device_map != "auto": - init_kwargs["device_map"] = {"": get_current_device()} + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.") quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) quant_method = quantization_config.get("quant_method", "") @@ -105,46 +134,72 @@ def configure_quantization( logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) elif model_args.export_quantization_bit is not None: # auto-gptq - require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + if model_args.export_quantization_bit not in [8, 4, 3, 2]: + raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.") + + require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0") require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") from accelerate.utils import get_max_memory if getattr(config, "model_type", None) == "chatglm": - raise ValueError("ChatGLM model is not supported.") + raise ValueError("ChatGLM model is not supported yet.") init_kwargs["quantization_config"] = GPTQConfig( bits=model_args.export_quantization_bit, - tokenizer=tokenizer, dataset=_get_quantization_dataset(tokenizer, model_args), ) init_kwargs["device_map"] = "auto" init_kwargs["max_memory"] = get_max_memory() - logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) + logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) - elif model_args.quantization_bit is not None: # bnb - if model_args.quantization_bit == 8: - require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + elif model_args.quantization_bit is not None: # on-the-fly + if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora + ) + else: + raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.") - elif model_args.quantization_bit == 4: - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=model_args.compute_dtype, - bnb_4bit_use_double_quant=model_args.double_quantization, - bnb_4bit_quant_type=model_args.quantization_type, - bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora - ) + # Do not assign device map if: + # 1. deepspeed zero3 or fsdp (train) + # 2. auto quantization device map (inference) + if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": + if model_args.quantization_bit != 4: + raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") - if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": - if model_args.quantization_bit != 4: - raise ValueError("Only 4-bit quantized model can use auto device map.") + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + else: + init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference - require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") - require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0") - require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") - init_kwargs["torch_dtype"] = model_args.compute_dtype # fsdp+qlora requires same dtype - else: - init_kwargs["device_map"] = {"": get_current_device()} + logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.HQQ.value: + if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]: + raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.") - logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") + + require_version("hqq", "To fix: pip install hqq") + init_kwargs["quantization_config"] = HqqConfig( + nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 + ) # use ATEN kernel (axis=0) for performance + logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.EETQ.value: + if model_args.quantization_bit != 8: + raise ValueError("EETQ only accepts 8-bit quantization.") + + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") + + require_version("eetq", "To fix: pip install eetq") + init_kwargs["quantization_config"] = EetqConfig() + logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit)) diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py index 93ab8929..4373ee19 100644 --- a/src/llamafactory/model/model_utils/rope.py +++ b/src/llamafactory/model/model_utils/rope.py @@ -1,3 +1,21 @@ +# Copyright 2024 LMSYS and the LlamaFactory team. +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# This code is inspired by the LMSYS's FastChat library. +# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from typing import TYPE_CHECKING @@ -21,8 +39,8 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ logger.warning("Current model does not support RoPE scaling.") return - if is_trainable: - if model_args.rope_scaling == "dynamic": + if model_args.model_max_length is not None: + if is_trainable and model_args.rope_scaling == "dynamic": logger.warning( "Dynamic NTK scaling may not work well with fine-tuning. " "See: https://github.com/huggingface/transformers/pull/24653" diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py index 8a16409d..9cfaec61 100644 --- a/src/llamafactory/model/model_utils/unsloth.py +++ b/src/llamafactory/model/model_utils/unsloth.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, Optional from ...extras.logging import get_logger diff --git a/src/llamafactory/model/model_utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py index 64333688..9ab3d45a 100644 --- a/src/llamafactory/model/model_utils/valuehead.py +++ b/src/llamafactory/model/model_utils/valuehead.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict import torch diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index c8260b7f..700bf470 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Tuple import torch diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 47591de6..f1831ced 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from types import MethodType from typing import TYPE_CHECKING, Any, Dict @@ -46,13 +60,16 @@ def patch_config( is_trainable: bool, ) -> None: if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 - model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + if model_args.infer_dtype != "auto" and not is_trainable: + model_args.compute_dtype = getattr(torch, model_args.infer_dtype) + else: + model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) if is_torch_npu_available(): use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"] torch.npu.set_compile_mode(jit_compile=use_jit_compile) - configure_attn_implementation(config, model_args) + configure_attn_implementation(config, model_args, is_trainable) configure_rope(config, model_args, is_trainable) configure_longlora(config, model_args, is_trainable) configure_quantization(config, tokenizer, model_args, init_kwargs) @@ -74,14 +91,17 @@ def patch_config( # deepspeed zero3 is not compatible with low_cpu_mem_usage init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled()) - if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp + # cast data type of the model if: + # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32) + # 2. quantization_bit is not None (qlora) + if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None: init_kwargs["torch_dtype"] = model_args.compute_dtype if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True if "device_map" not in init_kwargs and model_args.device_map: init_kwargs["device_map"] = model_args.device_map - if init_kwargs["device_map"] == "auto": + if init_kwargs.get("device_map", None) == "auto": init_kwargs["offload_folder"] = model_args.offload_folder if finetune_args.stage == "sft" and data_args.efficient_packing: @@ -137,6 +157,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: if isinstance(self.pretrained_model, PreTrainedModel): return self.pretrained_model.get_input_embeddings() + def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_output_embeddings() + def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None: if isinstance(self.pretrained_model, PeftModel): self.pretrained_model.create_or_update_model_card(output_dir) @@ -145,4 +169,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: setattr(model, "_keys_to_ignore_on_save", ignore_modules) setattr(model, "tie_weights", MethodType(tie_weights, model)) setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) + setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model)) setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model)) diff --git a/src/llamafactory/extras/callbacks.py b/src/llamafactory/train/callbacks.py similarity index 56% rename from src/llamafactory/extras/callbacks.py rename to src/llamafactory/train/callbacks.py index 441ebbfd..4d024278 100644 --- a/src/llamafactory/extras/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import logging import os @@ -8,22 +22,78 @@ from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from typing import TYPE_CHECKING, Any, Dict, Optional +import torch import transformers -from transformers import TrainerCallback +from peft import PeftModel +from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length +from transformers.utils import ( + SAFE_WEIGHTS_NAME, + WEIGHTS_NAME, + is_safetensors_available, +) -from .constants import TRAINER_LOG -from .logging import LoggerHandler, get_logger -from .misc import fix_valuehead_checkpoint +from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ..extras.logging import LoggerHandler, get_logger +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import save_file + if TYPE_CHECKING: from transformers import TrainerControl, TrainerState, TrainingArguments + from trl import AutoModelForCausalLMWithValueHead logger = get_logger(__name__) +def fix_valuehead_checkpoint( + model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool +) -> None: + r""" + The model is already unwrapped. + + There are three cases: + 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...} + 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...} + 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...} + + We assume `stage3_gather_16bit_weights_on_model_save=true`. + """ + if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)): + return + + if safe_serialization: + path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) + with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: + state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} + else: + path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) + state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") + + decoder_state_dict = {} + v_head_state_dict = {} + for name, param in state_dict.items(): + if name.startswith("v_head."): + v_head_state_dict[name] = param + else: + decoder_state_dict[name.replace("pretrained_model.", "")] = param + + os.remove(path_to_checkpoint) + model.pretrained_model.save_pretrained( + output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization + ) + + if safe_serialization: + save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) + else: + torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) + + logger.info("Value head model saved at: {}".format(output_dir)) + + class FixValueHeadModelCallback(TrainerCallback): def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" @@ -37,8 +107,70 @@ class FixValueHeadModelCallback(TrainerCallback): ) +class SaveProcessorCallback(TrainerCallback): + def __init__(self, processor: "ProcessorMixin") -> None: + r""" + Initializes a callback for saving the processor. + """ + self.processor = processor + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if args.should_save: + getattr(self.processor, "image_processor").save_pretrained(args.output_dir) + + +class PissaConvertCallback(TrainerCallback): + r""" + Initializes a callback for converting the PiSSA adapter to a normal one. + """ + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. + """ + if args.should_save: + model = kwargs.pop("model") + pissa_init_dir = os.path.join(args.output_dir, "pissa_init") + logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir)) + if isinstance(model, PeftModel): + init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") + setattr(model.peft_config["default"], "init_lora_weights", True) + model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors) + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if args.should_save: + model = kwargs.pop("model") + pissa_init_dir = os.path.join(args.output_dir, "pissa_init") + pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup") + pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted") + logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir)) + # 1. save a pissa backup with init_lora_weights: True + # 2. save a converted lora with init_lora_weights: pissa + # 3. load the pissa backup with init_lora_weights: True + # 4. delete the initial adapter and change init_lora_weights to pissa + if isinstance(model, PeftModel): + init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") + setattr(model.peft_config["default"], "init_lora_weights", True) + model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors) + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + model.save_pretrained( + pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir + ) + model.load_adapter(pissa_backup_dir, "default", is_trainable=True) + model.set_adapter("default") + model.delete_adapter("pissa_init") + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + + class LogCallback(TrainerCallback): - def __init__(self, output_dir: str) -> None: + def __init__(self) -> None: r""" Initializes a callback for logging training and evaluation status. """ @@ -56,7 +188,7 @@ class LogCallback(TrainerCallback): self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] if self.webui_mode: signal.signal(signal.SIGABRT, self._set_abort) - self.logger_handler = LoggerHandler(output_dir) + self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) logging.root.addHandler(self.logger_handler) transformers.logging.add_handler(self.logger_handler) diff --git a/src/llamafactory/train/dpo/__init__.py b/src/llamafactory/train/dpo/__init__.py index 43fe9420..9ce0d089 100644 --- a/src/llamafactory/train/dpo/__init__.py +++ b/src/llamafactory/train/dpo/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_dpo diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index d860b29a..e45467d6 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -1,3 +1,21 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType @@ -10,7 +28,8 @@ from trl import DPOTrainer from trl.trainer import disable_dropout_in_model from ...extras.constants import IGNORE_INDEX -from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context +from ..callbacks import PissaConvertCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps if TYPE_CHECKING: @@ -35,7 +54,6 @@ class CustomDPOTrainer(DPOTrainer): disable_dropout_in_model(ref_model) self.finetuning_args = finetuning_args - self.processor = processor self.reference_free = False self.use_dpo_data_collator = True # hack to avoid warning self.generate_during_eval = False # disable at evaluation @@ -61,6 +79,8 @@ class CustomDPOTrainer(DPOTrainer): if not hasattr(self, "accelerator"): raise AttributeError("Please update `transformers`.") + warnings.simplefilter("ignore") # remove gc warnings on ref model + if ref_model is not None: if self.is_deepspeed_enabled: if not ( @@ -71,10 +91,17 @@ class CustomDPOTrainer(DPOTrainer): self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model.eval() - if finetuning_args.use_badam: - from badam import clip_grad_norm_for_sparse_tensor + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) - self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if finetuning_args.pissa_convert: + self.callback_handler.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: @@ -87,12 +114,6 @@ class CustomDPOTrainer(DPOTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) - def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: - super()._save(output_dir, state_dict) - if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) - def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": r""" Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model. @@ -176,7 +197,7 @@ class CustomDPOTrainer(DPOTrainer): if self.ref_model is None: ref_model = model - ref_context = get_ref_context(self.accelerator, model) + ref_context = self.accelerator.unwrap_model(model).disable_adapter() else: ref_model = self.ref_model ref_context = nullcontext() diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py index 992985b0..431b5285 100644 --- a/src/llamafactory/train/dpo/workflow.py +++ b/src/llamafactory/train/dpo/workflow.py @@ -1,4 +1,19 @@ -# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING, List, Optional diff --git a/src/llamafactory/train/kto/__init__.py b/src/llamafactory/train/kto/__init__.py index 34c7905a..a1900368 100644 --- a/src/llamafactory/train/kto/__init__.py +++ b/src/llamafactory/train/kto/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_kto diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 22a84e4a..460311e4 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -1,3 +1,21 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType @@ -9,7 +27,8 @@ from trl import KTOTrainer from trl.trainer import disable_dropout_in_model from ...extras.constants import IGNORE_INDEX -from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context +from ..callbacks import SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps if TYPE_CHECKING: @@ -35,7 +54,6 @@ class CustomKTOTrainer(KTOTrainer): disable_dropout_in_model(ref_model) self.finetuning_args = finetuning_args - self.processor = processor self.reference_free = False self.use_dpo_data_collator = True # hack to avoid warning self.generate_during_eval = False # disable at evaluation @@ -60,6 +78,8 @@ class CustomKTOTrainer(KTOTrainer): if not hasattr(self, "accelerator"): raise AttributeError("Please update `transformers`.") + warnings.simplefilter("ignore") # remove gc warnings on ref model + if ref_model is not None: if self.is_deepspeed_enabled: if not ( @@ -70,10 +90,14 @@ class CustomKTOTrainer(KTOTrainer): self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model.eval() - if finetuning_args.use_badam: - from badam import clip_grad_norm_for_sparse_tensor + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) - self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: @@ -92,12 +116,6 @@ class CustomKTOTrainer(KTOTrainer): """ return Trainer._get_train_sampler(self) - def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: - super()._save(output_dir, state_dict) - if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) - def forward( self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" ) -> Tuple["torch.Tensor", "torch.Tensor"]: @@ -143,7 +161,7 @@ class CustomKTOTrainer(KTOTrainer): """ if self.ref_model is None: ref_model = model - ref_context = get_ref_context(self.accelerator, model) + ref_context = self.accelerator.unwrap_model(model).disable_adapter() else: ref_model = self.ref_model ref_context = nullcontext() diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index c79b160b..8182a184 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, List, Optional from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset diff --git a/src/llamafactory/train/ppo/__init__.py b/src/llamafactory/train/ppo/__init__.py index d17336d5..161f6f5d 100644 --- a/src/llamafactory/train/ppo/__init__.py +++ b/src/llamafactory/train/ppo/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_ppo diff --git a/src/llamafactory/train/ppo/ppo_utils.py b/src/llamafactory/train/ppo/ppo_utils.py index fec3fc1e..05c40946 100644 --- a/src/llamafactory/train/ppo/ppo_utils.py +++ b/src/llamafactory/train/ppo/ppo_utils.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from contextlib import nullcontext from typing import TYPE_CHECKING, Dict, List, Literal, Optional diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 2e1288e4..57f0b848 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -1,6 +1,24 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import os import sys +import warnings from types import MethodType from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -9,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs from tqdm import tqdm from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState from transformers.optimization import get_scheduler +from transformers.trainer_callback import CallbackHandler from transformers.trainer_pt_utils import remove_dummy_checkpoint from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME @@ -16,9 +35,9 @@ from trl import PPOConfig, PPOTrainer from trl.core import PPODecorators, logprobs_from_logits from trl.models.utils import unwrap_model_for_generation -from ...extras.callbacks import FixValueHeadModelCallback, LogCallback from ...extras.logging import get_logger from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor +from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimzer, create_custom_scheduler from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm @@ -81,10 +100,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ) # Add deepspeed config - ppo_config.accelerator_kwargs["kwargs_handlers"] = [ - DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters) - ] if training_args.deepspeed_plugin is not None: + ppo_config.accelerator_kwargs["kwargs_handlers"] = [ + DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters) + ] ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin # Create optimizer and scheduler @@ -113,7 +132,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.finetuning_args = finetuning_args self.reward_model = reward_model self.current_device = get_current_device() # patch for deepspeed training - self.processor = processor self.generation_config = GenerationConfig( pad_token_id=self.tokenizer.pad_token_id, @@ -125,8 +143,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.control = TrainerControl() self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None - self.log_callback, self.save_callback = callbacks[0], callbacks[1] - assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback) + self.callback_handler = CallbackHandler( + [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler + ) if self.args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") @@ -134,8 +153,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer): unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm" - device_type = unwrapped_model.pretrained_model.device.type - self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype) + self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype) + warnings.simplefilter("ignore") # remove gc warnings on ref model if finetuning_args.reward_model_type == "full": if self.is_deepspeed_enabled: @@ -147,10 +166,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer): else: self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) - if finetuning_args.use_badam: - from badam import clip_grad_norm_for_sparse_tensor + self.add_callback(FixValueHeadModelCallback) - self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: r""" @@ -184,23 +209,23 @@ class CustomPPOTrainer(PPOTrainer, Trainer): if self.is_world_process_zero(): logger.info("***** Running training *****") - logger.info(" Num examples = {}".format(num_examples)) - logger.info(" Num Epochs = {}".format(num_train_epochs)) - logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size)) + logger.info(" Num examples = {:,}".format(num_examples)) + logger.info(" Num Epochs = {:,}".format(num_train_epochs)) + logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size)) logger.info( - " Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format( + " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format( total_train_batch_size ) ) - logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps)) - logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs)) - logger.info(" Total training steps = {}".format(max_steps)) - logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0])) + logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps)) + logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs)) + logger.info(" Total training steps = {:,}".format(max_steps)) + logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0])) dataiter = iter(self.dataloader) loss_meter = AverageMeter() reward_meter = AverageMeter() - self.log_callback.on_train_begin(self.args, self.state, self.control) + self.callback_handler.on_train_begin(self.args, self.state, self.control) for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()): try: @@ -238,7 +263,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): logger.warning("Failed to save stats due to unknown errors.") self.state.global_step += 1 - self.log_callback.on_step_end(self.args, self.state, self.control) + self.callback_handler.on_step_end(self.args, self.state, self.control) if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0: logs = dict( @@ -250,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): tqdm.write(str(logs)) logs["step"] = step self.state.log_history.append(logs) - self.log_callback.on_log(self.args, self.state, self.control) + self.callback_handler.on_log(self.args, self.state, self.control, logs) loss_meter.reset() reward_meter.reset() @@ -258,17 +283,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.save_model( os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)) ) - self.save_callback.on_save( - self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) - ) + self.callback_handler.on_save(self.args, self.state, self.control) if self.control.should_epoch_stop or self.control.should_training_stop: break - self.log_callback.on_train_end(self.args, self.state, self.control) - self.save_callback.on_train_end( - self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) - ) + self.callback_handler.on_train_end(self.args, self.state, self.control) def create_optimizer( self, @@ -486,7 +506,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer): elif self.args.should_save: self._save(output_dir) - - if self.processor is not None and self.args.should_save: - output_dir = output_dir if output_dir is not None else self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py index 111704c6..651296f3 100644 --- a/src/llamafactory/train/ppo/workflow.py +++ b/src/llamafactory/train/ppo/workflow.py @@ -1,14 +1,28 @@ -# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING, List, Optional from transformers import DataCollatorWithPadding from ...data import get_dataset -from ...extras.callbacks import FixValueHeadModelCallback -from ...extras.misc import fix_valuehead_checkpoint from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer +from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint from ..trainer_utils import create_ref_model, create_reward_model from .trainer import CustomPPOTrainer @@ -60,6 +74,7 @@ def run_ppo( ppo_trainer.save_model() if training_args.should_save: fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) + ppo_trainer.save_state() # must be called after save_model to have a folder if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "reward"]) diff --git a/src/llamafactory/train/pt/__init__.py b/src/llamafactory/train/pt/__init__.py index bdf397f6..d80e6f22 100644 --- a/src/llamafactory/train/pt/__init__.py +++ b/src/llamafactory/train/pt/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_pt diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 1d96e82f..e8f180a6 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -1,9 +1,24 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from types import MethodType -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional from transformers import Trainer from ...extras.logging import get_logger +from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimzer, create_custom_scheduler @@ -27,11 +42,18 @@ class CustomTrainer(Trainer): ) -> None: super().__init__(**kwargs) self.finetuning_args = finetuning_args - self.processor = processor - if finetuning_args.use_badam: - from badam import clip_grad_norm_for_sparse_tensor - self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: @@ -43,9 +65,3 @@ class CustomTrainer(Trainer): ) -> "torch.optim.lr_scheduler.LRScheduler": create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) - - def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: - super()._save(output_dir, state_dict) - if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index 8a635567..b84a0e7d 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -1,4 +1,19 @@ -# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math from typing import TYPE_CHECKING, List, Optional diff --git a/src/llamafactory/train/rm/__init__.py b/src/llamafactory/train/rm/__init__.py index dedac35f..48278315 100644 --- a/src/llamafactory/train/rm/__init__.py +++ b/src/llamafactory/train/rm/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_rm diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py index 99dc6ab8..fb880b1c 100644 --- a/src/llamafactory/train/rm/metric.py +++ b/src/llamafactory/train/rm/metric.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Dict, Sequence, Tuple, Union import numpy as np diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index bfb344dc..accc877d 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -1,3 +1,42 @@ +# Copyright 2024 the LlamaFactory team. +# +# This code is inspired by the CarperAI's trlx library. +# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2022 CarperAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + import json import os from types import MethodType @@ -7,6 +46,7 @@ import torch from transformers import Trainer from ...extras.logging import get_logger +from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimzer, create_custom_scheduler @@ -30,12 +70,20 @@ class PairwiseTrainer(Trainer): ) -> None: super().__init__(**kwargs) self.finetuning_args = finetuning_args - self.processor = processor self.can_return_loss = True # override property to return eval_loss - if finetuning_args.use_badam: - from badam import clip_grad_norm_for_sparse_tensor + self.add_callback(FixValueHeadModelCallback) - self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: @@ -48,12 +96,6 @@ class PairwiseTrainer(Trainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) - def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: - super()._save(output_dir, state_dict) - if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) - def compute_loss( self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: @@ -63,7 +105,7 @@ class PairwiseTrainer(Trainer): Subclass and override to inject custom behavior. Note that the first element will be removed from the output tuple. - See: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/trainer.py#L3777 + See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842 """ # Compute rewards _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) @@ -79,7 +121,6 @@ class PairwiseTrainer(Trainer): chosen_scores, rejected_scores = [], [] # Compute pairwise loss. Only backprop on the different tokens before padding - # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py loss = 0 for i in range(batch_size): chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1 @@ -125,4 +166,5 @@ class PairwiseTrainer(Trainer): res: List[str] = [] for c_score, r_score in zip(chosen_scores, rejected_scores): res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) + writer.write("\n".join(res)) diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py index 2e9e194b..e0b32b77 100644 --- a/src/llamafactory/train/rm/workflow.py +++ b/src/llamafactory/train/rm/workflow.py @@ -1,12 +1,48 @@ -# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py +# Copyright 2024 the LlamaFactory team. +# +# This code is inspired by the CarperAI's trlx library. +# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2022 CarperAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. from typing import TYPE_CHECKING, List, Optional from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset -from ...extras.callbacks import FixValueHeadModelCallback -from ...extras.misc import fix_valuehead_checkpoint from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer +from ..callbacks import fix_valuehead_checkpoint from ..trainer_utils import create_modelcard_and_push from .metric import compute_accuracy from .trainer import PairwiseTrainer @@ -40,7 +76,7 @@ def run_rm( args=training_args, finetuning_args=finetuning_args, data_collator=data_collator, - callbacks=callbacks + [FixValueHeadModelCallback()], + callbacks=callbacks, compute_metrics=compute_accuracy, **tokenizer_module, **split_dataset(dataset, data_args, training_args), @@ -52,6 +88,7 @@ def run_rm( trainer.save_model() if training_args.should_save: fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors) + trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() diff --git a/src/llamafactory/train/sft/__init__.py b/src/llamafactory/train/sft/__init__.py index f2f84e78..475dfe5f 100644 --- a/src/llamafactory/train/sft/__init__.py +++ b/src/llamafactory/train/sft/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_sft diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index b135fcfb..c69608c0 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -1,14 +1,35 @@ +# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Dict import numpy as np +import torch +from transformers import EvalPrediction +from transformers.utils import is_jieba_available, is_nltk_available from ...extras.constants import IGNORE_INDEX -from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available +from ...extras.packages import is_rouge_available if TYPE_CHECKING: - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer if is_jieba_available(): @@ -23,6 +44,22 @@ if is_rouge_available(): from rouge_chinese import Rouge +def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]: + preds, labels = eval_preds.predictions, eval_preds.label_ids + accuracies = [] + for i in range(len(preds)): + pred, label = preds[i, :-1], labels[i, 1:] + label_mask = label != IGNORE_INDEX + accuracies.append(np.mean(pred[label_mask] == label[label_mask])) + + return {"accuracy": float(np.mean(accuracies))} + + +def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor": + logits = logits[0] if isinstance(logits, (list, tuple)) else logits + return torch.argmax(logits, dim=-1) + + @dataclass class ComputeMetrics: r""" @@ -31,11 +68,11 @@ class ComputeMetrics: tokenizer: "PreTrainedTokenizer" - def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: + def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]: r""" Uses the model predictions to compute metrics. """ - preds, labels = eval_preds + preds, labels = eval_preds.predictions, eval_preds.label_ids score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index c063b214..954bb69f 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from types import MethodType @@ -9,10 +26,12 @@ from transformers import Seq2SeqTrainer from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger +from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: + from torch.utils.data import Dataset from transformers import ProcessorMixin from transformers.trainer import PredictionOutput @@ -32,11 +51,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): ) -> None: super().__init__(**kwargs) self.finetuning_args = finetuning_args - self.processor = processor - if finetuning_args.use_badam: - from badam import clip_grad_norm_for_sparse_tensor - self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.add_callback(PissaConvertCallback) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: @@ -49,12 +75,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) - def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: - super()._save(output_dir, state_dict) - if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) - def prediction_step( self, model: "torch.nn.Module", @@ -94,7 +114,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding return padded_tensor.contiguous() # in contiguous memory - def save_predictions(self, predict_results: "PredictionOutput") -> None: + def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None: r""" Saves model predictions to `output_dir`. @@ -115,18 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): for i in range(len(preds)): pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] - if len(pad_len): - preds[i] = np.concatenate( - (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1 - ) # move pad token to last + if len(pad_len): # move pad token to last + preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) - decoded_labels = self.tokenizer.batch_decode( - labels, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) + decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) with open(output_prediction_file, "w", encoding="utf-8") as writer: res: List[str] = [] - for label, pred in zip(decoded_labels, decoded_preds): - res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds): + res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False)) + writer.write("\n".join(res)) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index f1e000bd..c12a70aa 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -1,4 +1,19 @@ -# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING, List, Optional @@ -10,7 +25,7 @@ from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push -from .metric import ComputeMetrics +from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor from .trainer import CustomSeq2SeqTrainer if TYPE_CHECKING: @@ -56,7 +71,8 @@ def run_sft( finetuning_args=finetuning_args, data_collator=data_collator, callbacks=callbacks, - compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, + compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy, + preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor, **tokenizer_module, **split_dataset(dataset, data_args, training_args), ) @@ -75,7 +91,7 @@ def run_sft( trainer.save_metrics("train", train_result.metrics) trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: - plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"]) # Evaluation if training_args.do_eval: @@ -92,7 +108,7 @@ def run_sft( predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) - trainer.save_predictions(predict_results) + trainer.save_predictions(dataset, predict_results) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 0ddcdb11..4b581691 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -1,8 +1,27 @@ -from contextlib import contextmanager +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore +# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus +# and the original BAdam's implementation: https://github.com/Ledzy/BAdam +# and the HuggingFace's TRL library: https://github.com/huggingface/trl +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import Trainer +from transformers.integrations import is_deepspeed_zero3_enabled from transformers.optimization import get_scheduler from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.trainer_pt_utils import get_parameter_names @@ -19,7 +38,6 @@ if is_galore_available(): if TYPE_CHECKING: - from accelerate import Accelerator from transformers import PreTrainedModel, Seq2SeqTrainingArguments from trl import AutoModelForCausalLMWithValueHead @@ -83,15 +101,12 @@ def create_ref_model( The valuehead parameter is randomly initialized since it is useless for PPO training. """ if finetuning_args.ref_model is not None: - ref_model_args_dict = model_args.to_dict() - ref_model_args_dict.update( - dict( - model_name_or_path=finetuning_args.ref_model, - adapter_name_or_path=finetuning_args.ref_model_adapters, - quantization_bit=finetuning_args.ref_model_quantization_bit, - ) + ref_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.ref_model, + adapter_name_or_path=finetuning_args.ref_model_adapters, + quantization_bit=finetuning_args.ref_model_quantization_bit, ) - ref_model_args = ModelArguments(**ref_model_args_dict) ref_finetuning_args = FinetuningArguments() tokenizer = load_tokenizer(ref_model_args)["tokenizer"] ref_model = load_model( @@ -102,9 +117,11 @@ def create_ref_model( if finetuning_args.finetuning_type == "lora": ref_model = None else: - tokenizer = load_tokenizer(model_args)["tokenizer"] + ref_model_args = ModelArguments.copyfrom(model_args) + ref_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(ref_model_args)["tokenizer"] ref_model = load_model( - tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead + tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead ) logger.info("Created reference model from the model itself.") @@ -139,15 +156,12 @@ def create_reward_model( logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) return None else: - reward_model_args_dict = model_args.to_dict() - reward_model_args_dict.update( - dict( - model_name_or_path=finetuning_args.reward_model, - adapter_name_or_path=finetuning_args.reward_model_adapters, - quantization_bit=finetuning_args.reward_model_quantization_bit, - ) + reward_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.reward_model, + adapter_name_or_path=finetuning_args.reward_model_adapters, + quantization_bit=finetuning_args.reward_model_quantization_bit, ) - reward_model_args = ModelArguments(**reward_model_args_dict) reward_finetuning_args = FinetuningArguments() tokenizer = load_tokenizer(reward_model_args)["tokenizer"] reward_model = load_model( @@ -158,17 +172,6 @@ def create_reward_model( return reward_model -@contextmanager -def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"): - r""" - Gets adapter context for the reference model. - """ - with accelerator.unwrap_model(model).disable_adapter(): - model.eval() - yield - model.train() - - def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: r""" Returns a list of names of parameters with weight decay. (weights in non-layernorm layers) @@ -184,7 +187,7 @@ def _create_galore_optimizer( finetuning_args: "FinetuningArguments", ) -> "torch.optim.Optimizer": if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": - galore_targets = find_all_linear_modules(model) + galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) else: galore_targets = finetuning_args.galore_target @@ -334,6 +337,7 @@ def _create_badam_optimizer( start_block=finetuning_args.badam_start_block, switch_mode=finetuning_args.badam_switch_mode, verbose=finetuning_args.badam_verbose, + ds_zero3_enabled=is_deepspeed_zero3_enabled(), ) logger.info( f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, " @@ -355,7 +359,7 @@ def _create_badam_optimizer( **optim_kwargs, ) logger.info( - f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, " + f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, " f"mask mode is {finetuning_args.badam_mask_mode}" ) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index eed875e9..dc982e07 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -1,13 +1,30 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from transformers import PreTrainedModel from ..data import get_template_and_fix_tokenizer -from ..extras.callbacks import LogCallback +from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.logging import get_logger from ..hparams import get_infer_args, get_train_args from ..model import load_model, load_tokenizer +from .callbacks import LogCallback from .dpo import run_dpo from .kto import run_kto from .ppo import run_ppo @@ -24,8 +41,8 @@ logger = get_logger(__name__) def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: + callbacks.append(LogCallback()) model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) - callbacks.append(LogCallback(training_args.output_dir)) if finetuning_args.stage == "pt": run_pt(model_args, data_args, training_args, finetuning_args, callbacks) @@ -84,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: safe_serialization=(not model_args.export_legacy_format), ) + if finetuning_args.stage == "rm": + if model_args.adapter_name_or_path is not None: + vhead_path = model_args.adapter_name_or_path[-1] + else: + vhead_path = model_args.model_name_or_path + + if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + try: tokenizer.padding_side = "left" # restore padding side tokenizer.init_kwargs["padding_side"] = "left" diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index c82710d3..8abef920 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple @@ -9,7 +23,7 @@ from ..data import Role from ..extras.constants import PEFT_METHODS from ..extras.misc import torch_gc from ..extras.packages import is_gradio_available -from .common import get_save_dir +from .common import QUANTIZATION_BITS, get_save_dir from .locales import ALERTS @@ -62,17 +76,24 @@ class WebChatModel(ChatModel): yield error return + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + yield ALERTS["info_loading"][lang] args = dict( model_name_or_path=model_path, finetuning_type=finetuning_type, - quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), template=get("top.template"), flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), visual_inputs=get("top.visual_inputs"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, infer_backend=get("infer.infer_backend"), + infer_dtype=get("infer.infer_dtype"), ) if checkpoint_path: @@ -126,16 +147,15 @@ class WebChatModel(ChatModel): ): response += new_text if tools: - result = self.engine.template.format_tools.extract(response) + result = self.engine.template.extract_tool(response) else: result = response - if isinstance(result, tuple): - name, arguments = result - arguments = json.loads(arguments) - tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False) - output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}] - bot_text = "```json\n" + tool_call + "\n```" + if isinstance(result, list): + tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result] + tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False) + output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}] + bot_text = "```json\n" + tool_calls + "\n```" else: output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}] bot_text = result diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 37b38df0..bced18f0 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from collections import defaultdict @@ -33,13 +47,19 @@ DEFAULT_CONFIG_DIR = "config" DEFAULT_DATA_DIR = "data" DEFAULT_SAVE_DIR = "saves" USER_CONFIG = "user_config.yaml" +QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"] +GPTQ_BITS = ["8", "4", "3", "2"] def get_save_dir(*paths: str) -> os.PathLike: r""" Gets the path to saved model checkpoints. """ - paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths) + if os.path.sep in paths[-1]: + logger.warning("Found complex path, some features may be not available.") + return paths[-1] + + paths = (path.replace(" ", "").strip() for path in paths) return os.path.join(DEFAULT_SAVE_DIR, *paths) diff --git a/src/llamafactory/webui/components/__init__.py b/src/llamafactory/webui/components/__init__.py index 5c1e21b8..715fb6e4 100644 --- a/src/llamafactory/webui/components/__init__.py +++ b/src/llamafactory/webui/components/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .chatbot import create_chat_box from .eval import create_eval_tab from .export import create_export_tab diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py index f83694b1..ad74114b 100644 --- a/src/llamafactory/webui/components/chatbot.py +++ b/src/llamafactory/webui/components/chatbot.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict, Tuple from ...data import Role diff --git a/src/llamafactory/webui/components/data.py b/src/llamafactory/webui/components/data.py index 232b973d..88e500cf 100644 --- a/src/llamafactory/webui/components/data.py +++ b/src/llamafactory/webui/components/data.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from typing import TYPE_CHECKING, Any, Dict, List, Tuple diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py index 0a7a0f44..b522913e 100644 --- a/src/llamafactory/webui/components/eval.py +++ b/src/llamafactory/webui/components/eval.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict from ...extras.packages import is_gradio_available diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index 7e1493c8..0a938f02 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -1,10 +1,24 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict, Generator, List, Union from ...extras.constants import PEFT_METHODS from ...extras.misc import torch_gc from ...extras.packages import is_gradio_available from ...train.tuner import export_model -from ..common import get_save_dir +from ..common import GPTQ_BITS, get_save_dir from ..locales import ALERTS @@ -18,7 +32,11 @@ if TYPE_CHECKING: from ..engine import Engine -GPTQ_BITS = ["8", "4", "3", "2"] +def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": + if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: + return gr.Dropdown(value="none", interactive=False) + else: + return gr.Dropdown(interactive=True) def save_model( @@ -96,6 +114,9 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: export_dir = gr.Textbox() export_hub_model_id = gr.Textbox() + checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path") + checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False) + export_btn = gr.Button() info_box = gr.Textbox(show_label=False, interactive=False) diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py index 970f4629..a0064479 100644 --- a/src/llamafactory/webui/components/infer.py +++ b/src/llamafactory/webui/components/infer.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict from ...extras.packages import is_gradio_available @@ -18,15 +32,26 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: input_elems = engine.manager.get_base_elems() elem_dict = dict() - infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface") + with gr.Row(): + infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface") + infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto") + with gr.Row(): load_btn = gr.Button() unload_btn = gr.Button() info_box = gr.Textbox(show_label=False, interactive=False) - input_elems.update({infer_backend}) - elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box)) + input_elems.update({infer_backend, infer_dtype}) + elem_dict.update( + dict( + infer_backend=infer_backend, + infer_dtype=infer_dtype, + load_btn=load_btn, + unload_btn=unload_btn, + info_box=info_box, + ) + ) chatbot, messages, chat_elems = create_chat_box(engine, visible=False) elem_dict.update(chat_elems) diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index fd0ead3d..9df3f062 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -1,10 +1,24 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict from ...data import TEMPLATES from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.packages import is_gradio_available from ..common import get_model_info, list_checkpoints, save_config -from ..utils import can_quantize +from ..utils import can_quantize, can_quantize_to if is_gradio_available(): @@ -29,17 +43,23 @@ def create_top() -> Dict[str, "Component"]: with gr.Accordion(open=False) as advanced_tab: with gr.Row(): - quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2) - template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2) - rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3) - booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3) + quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1) + quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1) + template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1) + rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2) + booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2) visual_inputs = gr.Checkbox(scale=1) - model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False) + model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then( + list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False + ) model_name.input(save_config, inputs=[lang, model_name], queue=False) model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False) - finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False) + finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then( + list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False + ) checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) + quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False) return dict( lang=lang, @@ -49,6 +69,7 @@ def create_top() -> Dict[str, "Component"]: checkpoint_path=checkpoint_path, advanced_tab=advanced_tab, quantization_bit=quantization_bit, + quantization_method=quantization_method, template=template, rope_scaling=rope_scaling, booster=booster, diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index dccc8500..4636050b 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict from transformers.trainer_utils import SchedulerType @@ -40,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: num_train_epochs = gr.Textbox(value="3.0") max_grad_norm = gr.Textbox(value="1.0") max_samples = gr.Textbox(value="100000") - compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16") + compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16") input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type}) elem_dict.update( @@ -152,10 +166,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter = gr.Checkbox() with gr.Row(): - with gr.Column(scale=1): - use_rslora = gr.Checkbox() - use_dora = gr.Checkbox() - + use_rslora = gr.Checkbox() + use_dora = gr.Checkbox() + use_pissa = gr.Checkbox() lora_target = gr.Textbox(scale=2) additional_target = gr.Textbox(scale=2) @@ -168,6 +181,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter, use_rslora, use_dora, + use_pissa, lora_target, additional_target, } @@ -182,6 +196,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter=create_new_adapter, use_rslora=use_rslora, use_dora=use_dora, + use_pissa=use_pissa, lora_target=lora_target, additional_target=additional_target, ) @@ -279,7 +294,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Column(scale=1): loss_viewer = gr.Plot() - input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload}) + input_elems.update({output_dir, config_path, ds_stage, ds_offload}) elem_dict.update( dict( cmd_preview_btn=cmd_preview_btn, diff --git a/src/llamafactory/webui/css.py b/src/llamafactory/webui/css.py index 36e3d4c2..53982119 100644 --- a/src/llamafactory/webui/css.py +++ b/src/llamafactory/webui/css.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + CSS = r""" .duplicate-button { margin: auto !important; diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py index eb6142d3..04893215 100644 --- a/src/llamafactory/webui/engine.py +++ b/src/llamafactory/webui/engine.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict from .chatter import WebChatModel diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py index bae3ba76..d25f4d38 100644 --- a/src/llamafactory/webui/interface.py +++ b/src/llamafactory/webui/interface.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from ..extras.packages import is_gradio_available diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 05cf3bed..852b1b3c 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + LOCALES = { "lang": { "en": { @@ -71,15 +85,29 @@ LOCALES = { "quantization_bit": { "en": { "label": "Quantization bit", - "info": "Enable 4/8-bit model quantization (QLoRA).", + "info": "Enable quantization (QLoRA).", }, "ru": { "label": "Уровень квантования", - "info": "Включить 4/8-битное квантование модели (QLoRA).", + "info": "Включить квантование (QLoRA).", }, "zh": { "label": "量化等级", - "info": "启用 4/8 比特模型量化(QLoRA)。", + "info": "启用量化(QLoRA)。", + }, + }, + "quantization_method": { + "en": { + "label": "Quantization method", + "info": "Quantization algorithm to use.", + }, + "ru": { + "label": "Метод квантования", + "info": "Алгоритм квантования, который следует использовать.", + }, + "zh": { + "label": "量化方法", + "info": "使用的量化算法。", }, }, "template": { @@ -732,6 +760,20 @@ LOCALES = { "info": "使用权重分解的 LoRA。", }, }, + "use_pissa": { + "en": { + "label": "Use PiSSA", + "info": "Use PiSSA method.", + }, + "ru": { + "label": "используйте PiSSA", + "info": "Используйте метод PiSSA.", + }, + "zh": { + "label": "使用 PiSSA", + "info": "使用 PiSSA 方法。", + }, + }, "lora_target": { "en": { "label": "LoRA modules (optional)", @@ -1192,6 +1234,17 @@ LOCALES = { "label": "推理引擎", }, }, + "infer_dtype": { + "en": { + "label": "Inference data type", + }, + "ru": { + "label": "Тип данных для вывода", + }, + "zh": { + "label": "推理数据类型", + }, + }, "load_btn": { "en": { "value": "Load model", diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py index 326fdb8d..ebe9f1b9 100644 --- a/src/llamafactory/webui/manager.py +++ b/src/llamafactory/webui/manager.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple @@ -57,6 +71,7 @@ class Manager: self._id_to_elem["top.finetuning_type"], self._id_to_elem["top.checkpoint_path"], self._id_to_elem["top.quantization_bit"], + self._id_to_elem["top.quantization_method"], self._id_to_elem["top.template"], self._id_to_elem["top.rope_scaling"], self._id_to_elem["top.booster"], diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 852805da..ffec54e2 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from copy import deepcopy from subprocess import Popen, TimeoutExpired @@ -8,9 +22,9 @@ from transformers.trainer import TRAINING_ARGS_NAME from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.packages import is_gradio_available -from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config from .locales import ALERTS, LOCALES -from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd +from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd if is_gradio_available(): @@ -38,7 +52,7 @@ class Runner: def set_abort(self) -> None: self.aborted = True if self.trainer is not None: - abort_leaf_process(self.trainer.pid) + abort_process(self.trainer.pid) def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] @@ -90,6 +104,11 @@ class Runner: model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + args = dict( stage=TRAINING_STAGES[get("train.training_stage")], do_train=True, @@ -97,7 +116,8 @@ class Runner: cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, finetuning_type=finetuning_type, - quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", @@ -160,6 +180,8 @@ class Runner: args["create_new_adapter"] = get("train.create_new_adapter") args["use_rslora"] = get("train.use_rslora") args["use_dora"] = get("train.use_dora") + args["pissa_init"] = get("train.use_pissa") + args["pissa_convert"] = get("train.use_pissa") args["lora_target"] = get("train.lora_target") or "all" args["additional_target"] = get("train.additional_target") or None @@ -219,13 +241,19 @@ class Runner: model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() + if get("top.quantization_bit") in QUANTIZATION_BITS: + quantization_bit = int(get("top.quantization_bit")) + else: + quantization_bit = None + args = dict( stage="sft", model_name_or_path=get("top.model_path"), cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, finetuning_type=finetuning_type, - quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, + quantization_bit=quantization_bit, + quantization_method=get("top.quantization_method"), template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", @@ -283,6 +311,7 @@ class Runner: env = deepcopy(os.environ) env["LLAMABOARD_ENABLED"] = "1" + env["LLAMABOARD_WORKDIR"] = args["output_dir"] if args.get("deepspeed", None) is not None: env["FORCE_TORCHRUN"] = "1" @@ -291,7 +320,7 @@ class Runner: def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: config_dict = {} - skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"] + skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"] for elem, value in data.items(): elem_id = self.manager.get_id_by_elem(elem) if elem_id not in skip_ids: diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index e39f2aa4..6e5fdbe4 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os import signal @@ -11,6 +25,7 @@ from yaml import safe_dump, safe_load from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.ploting import gen_loss_plot +from ..model import QuantizationMethod from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir from .locales import ALERTS @@ -19,16 +34,19 @@ if is_gradio_available(): import gradio as gr -def abort_leaf_process(pid: int) -> None: +def abort_process(pid: int) -> None: r""" - Aborts the leaf processes. + Aborts the processes recursively in a bottom-up way. """ - children = psutil.Process(pid).children() - if children: - for child in children: - abort_leaf_process(child.pid) - else: + try: + children = psutil.Process(pid).children() + if children: + for child in children: + abort_process(child.pid) + os.kill(pid, signal.SIGABRT) + except Exception: + pass def can_quantize(finetuning_type: str) -> "gr.Dropdown": @@ -41,6 +59,20 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown": return gr.Dropdown(interactive=True) +def can_quantize_to(quantization_method: str) -> "gr.Dropdown": + r""" + Returns the available quantization bits. + """ + if quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + available_bits = ["none", "8", "4"] + elif quantization_method == QuantizationMethod.HQQ.value: + available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"] + elif quantization_method == QuantizationMethod.EETQ.value: + available_bits = ["none", "8"] + + return gr.Dropdown(choices=available_bits) + + def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: r""" Modifys states after changing the training stage. diff --git a/src/train.py b/src/train.py index b20aa9d2..6703ffdb 100644 --- a/src/train.py +++ b/src/train.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from llamafactory.train.tuner import run_exp diff --git a/src/webui.py b/src/webui.py index bbefb54e..99370af2 100644 --- a/src/webui.py +++ b/src/webui.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from llamafactory.webui.interface import create_ui diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py new file mode 100644 index 00000000..1845df24 --- /dev/null +++ b/tests/data/test_formatter.py @@ -0,0 +1,123 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter + + +def test_empty_formatter(): + formatter = EmptyFormatter(slots=["\n"]) + assert formatter.apply() == ["\n"] + + +def test_string_formatter(): + formatter = StringFormatter(slots=["", "Human: {{content}}\nAssistant:"]) + assert formatter.apply(content="Hi") == ["", "Human: Hi\nAssistant:"] + + +def test_function_formatter(): + formatter = FunctionFormatter(slots=[], tool_format="default") + tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) + assert formatter.apply(content=tool_calls) == [ + """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""" + ] + + +def test_multi_function_formatter(): + formatter = FunctionFormatter(slots=[], tool_format="default") + tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2) + assert formatter.apply(content=tool_calls) == [ + """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", + """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", + ] + + +def test_default_tool_formatter(): + formatter = ToolFormatter(tool_format="default") + tools = [ + { + "name": "test_tool", + "description": "tool_desc", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string", "description": "foo_desc"}, + "bar": {"type": "number", "description": "bar_desc"}, + }, + "required": ["foo"], + }, + } + ] + assert formatter.apply(content=json.dumps(tools)) == [ + "You have access to the following tools:\n" + "> Tool Name: test_tool\n" + "Tool Description: tool_desc\n" + "Tool Args:\n" + " - foo (string, required): foo_desc\n" + " - bar (number): bar_desc\n\n" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [test_tool]).\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{"input": "hello world", "num_beams": 5}```).\n""" + "```\n" + ] + + +def test_default_tool_extractor(): + formatter = ToolFormatter(tool_format="default") + result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +def test_default_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="default") + result = ( + """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n""" + """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + +def test_glm4_tool_formatter(): + formatter = ToolFormatter(tool_format="glm4") + tools = [ + { + "name": "test_tool", + "description": "tool_desc", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string", "description": "foo_desc"}, + "bar": {"type": "number", "description": "bar_desc"}, + }, + "required": ["foo"], + }, + } + ] + assert formatter.apply(content=json.dumps(tools)) == [ + "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n" + "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4)) + ] + + +def test_glm4_tool_extractor(): + formatter = ToolFormatter(tool_format="glm4") + result = """test_tool\n{"foo": "bar", "size": 10}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] diff --git a/tests/data/test_processor.py b/tests/data/test_processor.py new file mode 100644 index 00000000..fa8f7172 --- /dev/null +++ b/tests/data/test_processor.py @@ -0,0 +1,32 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import pytest + +from llamafactory.data.processors.processor_utils import infer_seqlen + + +@pytest.mark.parametrize( + "test_input,test_output", + [ + ((3000, 2000, 1000), (600, 400)), + ((2000, 3000, 1000), (400, 600)), + ((1000, 100, 1000), (900, 100)), + ((100, 1000, 1000), (100, 900)), + ], +) +def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]): + assert test_output == infer_seqlen(*test_input) diff --git a/tests/data/test_supervised.py b/tests/data/test_supervised.py index bb7f71df..9cb49615 100644 --- a/tests/data/test_supervised.py +++ b/tests/data/test_supervised.py @@ -1,24 +1,40 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os +import random import pytest from datasets import load_dataset +from transformers import AutoTokenizer from llamafactory.data import get_dataset from llamafactory.hparams import get_train_args from llamafactory.model import load_tokenizer -TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM") +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") -TRAINING_ARGS = { +TRAIN_ARGS = { "model_name_or_path": TINY_LLAMA, "stage": "sft", "do_train": True, "finetuning_type": "full", - "dataset": "llamafactory/tiny_dataset", + "dataset": "llamafactory/tiny-supervised-dataset", "dataset_dir": "ONLINE", "template": "llama3", - "cutoff_len": 1024, + "cutoff_len": 8192, "overwrite_cache": True, "output_dir": "dummy_dir", "overwrite_output_dir": True, @@ -26,19 +42,26 @@ TRAINING_ARGS = { } -@pytest.mark.parametrize("test_num", [5]) -def test_supervised(test_num: int): - model_args, data_args, training_args, _, _ = get_train_args(TRAINING_ARGS) +@pytest.mark.parametrize("num_samples", [16]) +def test_supervised(num_samples: int): + model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS) tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) - original_data = load_dataset(TRAINING_ARGS["dataset"], split="train") - for test_idx in range(test_num): - decode_result = tokenizer.decode(tokenized_data["input_ids"][test_idx]) + ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + + original_data = load_dataset(TRAIN_ARGS["dataset"], split="train") + indexes = random.choices(range(len(original_data)), k=num_samples) + for index in indexes: + prompt = original_data[index]["instruction"] + if original_data[index]["input"]: + prompt += "\n" + original_data[index]["input"] + messages = [ - {"role": "user", "content": original_data[test_idx]["instruction"]}, - {"role": "assistant", "content": original_data[test_idx]["output"]}, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": original_data[index]["output"]}, ] - templated_result = tokenizer.apply_chat_template(messages, tokenize=False) - assert decode_result == templated_result + templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False) + decoded_result = tokenizer.decode(tokenized_data["input_ids"][index]) + assert templated_result == decoded_result diff --git a/tests/data/test_template.py b/tests/data/test_template.py new file mode 100644 index 00000000..e4728a84 --- /dev/null +++ b/tests/data/test_template.py @@ -0,0 +1,80 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from transformers import AutoTokenizer + +from llamafactory.data import get_template_and_fix_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +MESSAGES = [ + {"role": "user", "content": "How are you"}, + {"role": "assistant", "content": "I am fine!"}, + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "很高兴认识你!"}, +] + + +def test_encode_oneturn(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + template = get_template_and_fix_tokenizer(tokenizer, name="llama3") + prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES) + assert tokenizer.decode(prompt_ids) == ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>" + + +def test_encode_multiturn(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + template = get_template_and_fix_tokenizer(tokenizer, name="llama3") + encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES) + assert tokenizer.decode(encoded_pairs[0][0]) == ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>" + assert tokenizer.decode(encoded_pairs[1][0]) == ( + "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>" + + +def test_jinja_template(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + get_template_and_fix_tokenizer(tokenizer, name="llama3") + assert tokenizer.chat_template != ref_tokenizer.chat_template + assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) + + +def test_qwen_template(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct") + template = get_template_and_fix_tokenizer(tokenizer, name="qwen") + prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES) + assert tokenizer.decode(prompt_ids) == ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\nHow are you<|im_end|>\n" + "<|im_start|>assistant\nI am fine!<|im_end|>\n" + "<|im_start|>user\n你好<|im_end|>\n" + "<|im_start|>assistant\n" + ) + assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>" diff --git a/tests/eval/test_eval_template.py b/tests/eval/test_eval_template.py new file mode 100644 index 00000000..f85d9d57 --- /dev/null +++ b/tests/eval/test_eval_template.py @@ -0,0 +1,91 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llamafactory.eval.template import get_eval_template + + +def test_eval_template_en(): + support_set = [ + { + "question": "Fewshot question", + "A": "Fewshot1", + "B": "Fewshot2", + "C": "Fewshot3", + "D": "Fewshot4", + "answer": "B", + } + ] + example = { + "question": "Target question", + "A": "Target1", + "B": "Target2", + "C": "Target3", + "D": "Target4", + "answer": "C", + } + template = get_eval_template(name="en") + messages = template.format_example(example, support_set=support_set, subject_name="SubName") + assert messages == [ + { + "role": "user", + "content": ( + "The following are multiple choice questions (with answers) about SubName.\n\n" + "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:" + ), + }, + {"role": "assistant", "content": "B"}, + { + "role": "user", + "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:", + }, + {"role": "assistant", "content": "C"}, + ] + + +def test_eval_template_zh(): + support_set = [ + { + "question": "示例问题", + "A": "示例答案1", + "B": "示例答案2", + "C": "示例答案3", + "D": "示例答案4", + "answer": "B", + } + ] + example = { + "question": "目标问题", + "A": "目标答案1", + "B": "目标答案2", + "C": "目标答案3", + "D": "目标答案4", + "answer": "C", + } + template = get_eval_template(name="zh") + messages = template.format_example(example, support_set=support_set, subject_name="主题") + assert messages == [ + { + "role": "user", + "content": ( + "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n" + "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:" + ), + }, + {"role": "assistant", "content": "B"}, + { + "role": "user", + "content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:", + }, + {"role": "assistant", "content": "C"}, + ] diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py index 4d414289..4cae3d7c 100644 --- a/tests/model/model_utils/test_attention.py +++ b/tests/model/model_utils/test_attention.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available @@ -6,11 +20,16 @@ from llamafactory.hparams import get_infer_args from llamafactory.model import load_model, load_tokenizer -TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM") +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "template": "llama3", +} def test_attention(): - attention_available = ["off"] + attention_available = ["disabled"] if is_torch_sdpa_available(): attention_available.append("sdpa") @@ -18,18 +37,12 @@ def test_attention(): attention_available.append("fa2") llama_attention_classes = { - "off": "LlamaAttention", + "disabled": "LlamaAttention", "sdpa": "LlamaSdpaAttention", "fa2": "LlamaFlashAttention2", } for requested_attention in attention_available: - model_args, _, finetuning_args, _ = get_infer_args( - { - "model_name_or_path": TINY_LLAMA, - "template": "llama2", - "flash_attn": requested_attention, - } - ) + model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS}) tokenizer_module = load_tokenizer(model_args) model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args) for module in model.modules(): diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py new file mode 100644 index 00000000..9b6dfc9e --- /dev/null +++ b/tests/model/model_utils/test_checkpointing.py @@ -0,0 +1,74 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "lora", + "lora_target": "all", + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + + +def test_checkpointing_enable(): + model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": False, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): + assert getattr(module, "gradient_checkpointing") is True + + +def test_checkpointing_disable(): + model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": True, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): + assert getattr(module, "gradient_checkpointing") is False + + +def test_upcast_layernorm(): + model_args, _, _, finetuning_args, _ = get_train_args({"upcast_layernorm": True, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for name, param in model.named_parameters(): + if param.ndim == 1 and "norm" in name: + assert param.dtype == torch.float32 + + +def test_upcast_lmhead_output(): + model_args, _, _, finetuning_args, _ = get_train_args({"upcast_lmhead_output": True, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device()) + outputs: "torch.Tensor" = model.get_output_embeddings()(inputs) + assert outputs.dtype == torch.float32 diff --git a/tests/model/test_base.py b/tests/model/test_base.py new file mode 100644 index 00000000..6431a504 --- /dev/null +++ b/tests/model/test_base.py @@ -0,0 +1,80 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Dict + +import pytest +import torch +from transformers import AutoModelForCausalLM +from trl import AutoModelForCausalLMWithValueHead + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_infer_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead") + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "template": "llama3", + "infer_dtype": "float16", +} + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"): + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) + + +@pytest.fixture +def fix_valuehead_cpu_loading(): + def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]): + state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + AutoModelForCausalLMWithValueHead.post_init = post_init + + +def test_base(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + ref_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device() + ) + compare_model(model, ref_model) + + +@pytest.mark.usefixtures("fix_valuehead_cpu_loading") +def test_valuehead(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model( + tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True + ) + + ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( + TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device() + ) + ref_model.v_head = ref_model.v_head.to(torch.float16) + compare_model(model, ref_model) diff --git a/tests/model/test_freeze.py b/tests/model/test_freeze.py index c6cdec78..5f478af6 100644 --- a/tests/model/test_freeze.py +++ b/tests/model/test_freeze.py @@ -1,19 +1,33 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import torch -from llamafactory.hparams import get_train_args +from llamafactory.hparams import get_infer_args, get_train_args from llamafactory.model import load_model, load_tokenizer -TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM") +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") -TRAINING_ARGS = { +TRAIN_ARGS = { "model_name_or_path": TINY_LLAMA, "stage": "sft", "do_train": True, "finetuning_type": "freeze", - "dataset": "llamafactory/tiny_dataset", + "dataset": "llamafactory/tiny-supervised-dataset", "dataset_dir": "ONLINE", "template": "llama3", "cutoff_len": 1024, @@ -23,16 +37,19 @@ TRAINING_ARGS = { "fp16": True, } +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "finetuning_type": "freeze", + "template": "llama3", + "infer_dtype": "float16", +} -def test_freeze_all_modules(): - model_args, _, _, finetuning_args, _ = get_train_args( - { - "freeze_trainable_layers": 1, - **TRAINING_ARGS, - } - ) + +def test_freeze_train_all_modules(): + model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS}) tokenizer_module = load_tokenizer(model_args) model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for name, param in model.named_parameters(): if name.startswith("model.layers.1."): assert param.requires_grad is True @@ -42,16 +59,13 @@ def test_freeze_all_modules(): assert param.dtype == torch.float16 -def test_freeze_extra_modules(): +def test_freeze_train_extra_modules(): model_args, _, _, finetuning_args, _ = get_train_args( - { - "freeze_trainable_layers": 1, - "freeze_extra_modules": "embed_tokens,lm_head", - **TRAINING_ARGS, - } + {"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS} ) tokenizer_module = load_tokenizer(model_args) model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for name, param in model.named_parameters(): if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]): assert param.requires_grad is True @@ -59,3 +73,13 @@ def test_freeze_extra_modules(): else: assert param.requires_grad is False assert param.dtype == torch.float16 + + +def test_freeze_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + for param in model.parameters(): + assert param.requires_grad is False + assert param.dtype == torch.float16 diff --git a/tests/model/test_full.py b/tests/model/test_full.py index ef57a980..0a6e0743 100644 --- a/tests/model/test_full.py +++ b/tests/model/test_full.py @@ -1,19 +1,33 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import torch -from llamafactory.hparams import get_train_args +from llamafactory.hparams import get_infer_args, get_train_args from llamafactory.model import load_model, load_tokenizer -TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM") +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") -TRAINING_ARGS = { +TRAIN_ARGS = { "model_name_or_path": TINY_LLAMA, "stage": "sft", "do_train": True, "finetuning_type": "full", - "dataset": "llamafactory/tiny_dataset", + "dataset": "llamafactory/tiny-supervised-dataset", "dataset_dir": "ONLINE", "template": "llama3", "cutoff_len": 1024, @@ -23,11 +37,29 @@ TRAINING_ARGS = { "fp16": True, } +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "finetuning_type": "full", + "template": "llama3", + "infer_dtype": "float16", +} -def test_full(): - model_args, _, _, finetuning_args, _ = get_train_args(TRAINING_ARGS) + +def test_full_train(): + model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS) tokenizer_module = load_tokenizer(model_args) model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for param in model.parameters(): assert param.requires_grad is True assert param.dtype == torch.float32 + + +def test_full_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + for param in model.parameters(): + assert param.requires_grad is False + assert param.dtype == torch.float16 diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py index 1f2c02ae..630e5f75 100644 --- a/tests/model/test_lora.py +++ b/tests/model/test_lora.py @@ -1,19 +1,43 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os +from typing import Dict, Sequence +import pytest import torch +from peft import LoraModel, PeftModel +from transformers import AutoModelForCausalLM +from trl import AutoModelForCausalLMWithValueHead -from llamafactory.hparams import get_train_args +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_infer_args, get_train_args from llamafactory.model import load_model, load_tokenizer -TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM") +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") -TRAINING_ARGS = { +TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora") + +TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead") + +TRAIN_ARGS = { "model_name_or_path": TINY_LLAMA, "stage": "sft", "do_train": True, "finetuning_type": "lora", - "dataset": "llamafactory/tiny_dataset", + "dataset": "llamafactory/tiny-supervised-dataset", "dataset_dir": "ONLINE", "template": "llama3", "cutoff_len": 1024, @@ -23,16 +47,70 @@ TRAINING_ARGS = { "fp16": True, } +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "adapter_name_or_path": TINY_LLAMA_ADAPTER, + "finetuning_type": "lora", + "template": "llama3", + "infer_dtype": "float16", +} -def test_lora_all_modules(): - model_args, _, _, finetuning_args, _ = get_train_args( - { - "lora_target": "all", - **TRAINING_ARGS, - } + +def load_reference_model(is_trainable: bool = False) -> "LoraModel": + model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device() ) + lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable) + for param in filter(lambda p: p.requires_grad, lora_model.parameters()): + param.data = param.data.to(torch.float32) + + return lora_model + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []): + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + if any(key in name for key in diff_keys): + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False + else: + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True + + +@pytest.fixture +def fix_valuehead_cpu_loading(): + def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]): + state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + AutoModelForCausalLMWithValueHead.post_init = post_init + + +def test_lora_train_qv_modules(): + model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS}) tokenizer_module = load_tokenizer(model_args) model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + linear_modules = set() + for name, param in model.named_parameters(): + if any(module in name for module in ["lora_A", "lora_B"]): + linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1]) + assert param.requires_grad is True + assert param.dtype == torch.float32 + else: + assert param.requires_grad is False + assert param.dtype == torch.float16 + + assert linear_modules == {"q_proj", "v_proj"} + + +def test_lora_train_all_modules(): + model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + linear_modules = set() for name, param in model.named_parameters(): if any(module in name for module in ["lora_A", "lora_B"]): @@ -46,16 +124,13 @@ def test_lora_all_modules(): assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"} -def test_lora_extra_modules(): +def test_lora_train_extra_modules(): model_args, _, _, finetuning_args, _ = get_train_args( - { - "lora_target": "all", - "additional_target": "embed_tokens,lm_head", - **TRAINING_ARGS, - } + {"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS} ) tokenizer_module = load_tokenizer(model_args) model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + extra_modules = set() for name, param in model.named_parameters(): if any(module in name for module in ["lora_A", "lora_B"]): @@ -70,3 +145,54 @@ def test_lora_extra_modules(): assert param.dtype == torch.float16 assert extra_modules == {"embed_tokens", "lm_head"} + + +def test_lora_train_old_adapters(): + model_args, _, _, finetuning_args, _ = get_train_args( + {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": False, **TRAIN_ARGS} + ) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + ref_model = load_reference_model(is_trainable=True) + compare_model(model, ref_model) + + +def test_lora_train_new_adapters(): + model_args, _, _, finetuning_args, _ = get_train_args( + {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": True, **TRAIN_ARGS} + ) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + ref_model = load_reference_model(is_trainable=True) + compare_model( + model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"] + ) + + +@pytest.mark.usefixtures("fix_valuehead_cpu_loading") +def test_lora_train_valuehead(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model( + tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True, add_valuehead=True + ) + + ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( + TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device() + ) + state_dict = model.state_dict() + ref_state_dict = ref_model.state_dict() + + assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"]) + assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"]) + + +def test_lora_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + ref_model = load_reference_model().merge_and_unload() + compare_model(model, ref_model) diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py new file mode 100644 index 00000000..030310d0 --- /dev/null +++ b/tests/model/test_pissa.py @@ -0,0 +1,90 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from peft import LoraModel, PeftModel +from transformers import AutoModelForCausalLM + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_infer_args, get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TINY_LLAMA_PISSA = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "lora", + "pissa_init": True, + "pissa_iter": -1, + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA_PISSA, + "adapter_name_or_path": TINY_LLAMA_PISSA, + "adapter_folder": "pissa_init", + "finetuning_type": "lora", + "template": "llama3", + "infer_dtype": "float16", +} + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"): + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) + + +def test_pissa_init(): + model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + base_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device() + ) + ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init", is_trainable=True) + for param in filter(lambda p: p.requires_grad, ref_model.parameters()): + param.data = param.data.to(torch.float32) + + compare_model(model, ref_model) + + +def test_pissa_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + base_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device() + ) + ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init") + ref_model = ref_model.merge_and_unload() + compare_model(model, ref_model)