diff --git a/.dockerignore b/.dockerignore index ce67d58a..2ac0e11d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,6 +4,8 @@ .venv cache data +hf_cache +output examples .dockerignore .gitattributes diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index ab2851c6..768adea6 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -13,6 +13,18 @@ body: - label: I have read the README and searched the existing issues. required: true + - type: textarea + id: system-info + validations: + required: true + attributes: + label: System Info + description: | + Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below. + 请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。 + + placeholder: llamafactory version, platform, python version, ... + - type: textarea id: reproduction validations: @@ -26,7 +38,9 @@ body: 请合理使用 Markdown 标签来格式化您的文本。 placeholder: | - python src/train_bash.py ... + ```bash + llamafactory-cli train ... + ``` - type: textarea id: expected-behavior @@ -38,18 +52,6 @@ body: Please provide a clear and concise description of what you would expect to happen. 请提供您原本的目的,即这段代码的期望行为。 - - type: textarea - id: system-info - validations: - required: false - attributes: - label: System Info - description: | - Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below. - 请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。 - - placeholder: transformers version, platform, python version, ... - - type: textarea id: others validations: diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b31e9d19..d23d6be3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -5,3 +5,4 @@ Fixes # (issue) ## Before submitting - [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)? +- [ ] Did you write any new necessary tests? 
diff --git a/.github/workflows/label_issue.yml b/.github/workflows/label_issue.yml new file mode 100644 index 00000000..b9a5543c --- /dev/null +++ b/.github/workflows/label_issue.yml @@ -0,0 +1,17 @@ +name: label_issue + +on: + issues: + types: + - opened + +jobs: + label_issue: + runs-on: ubuntu-latest + + steps: + - env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + ISSUE_URL: ${{ github.event.issue.html_url }} + run: | + gh issue edit $ISSUE_URL --add-label "pending" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f891f711..98bd9455 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,28 +2,44 @@ name: tests on: push: - branches: [ "main" ] + branches: + - main + paths: + - "**.py" + - "requirements.txt" + - ".github/workflows/*.yml" pull_request: - branches: [ "main" ] + branches: + - main + paths: + - "**.py" + - "requirements.txt" + - ".github/workflows/*.yml" jobs: - check_code_quality: - + tests: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - name: Checkout + uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.8" + cache: "pip" + cache-dependency-path: "setup.py" - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install ruff + python -m pip install .[torch,dev] - name: Check quality run: | - make style && make quality + make style && make quality + + - name: Test with pytest + run: | + make test diff --git a/Dockerfile b/Dockerfile index 0a35e355..3932ff30 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,14 +1,44 @@ -FROM nvcr.io/nvidia/pytorch:24.01-py3 +# Use the NVIDIA official image with PyTorch 2.3.0 +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html +FROM nvcr.io/nvidia/pytorch:24.02-py3 +# Define installation arguments +ARG INSTALL_BNB=false +ARG INSTALL_VLLM=false +ARG INSTALL_DEEPSPEED=false +ARG PIP_INDEX=https://pypi.org/simple + +# Set the working directory WORKDIR /app +# Install the requirements COPY requirements.txt /app/ -RUN pip install -r requirements.txt +RUN pip config set global.index-url $PIP_INDEX +RUN python -m pip install --upgrade pip +RUN python -m pip install -r requirements.txt +# Copy the rest of the application into the image COPY . 
/app/ -RUN pip install -e .[metrics,bitsandbytes,qwen] +# Install the LLaMA Factory +RUN EXTRA_PACKAGES="metrics"; \ + if [ "$INSTALL_BNB" = "true" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \ + fi; \ + if [ "$INSTALL_VLLM" = "true" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \ + fi; \ + if [ "$INSTALL_DEEPSPEED" = "true" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \ + fi; \ + pip install -e .[$EXTRA_PACKAGES] && \ + pip uninstall -y transformer-engine flash-attn + +# Set up volumes VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ] + +# Expose port 7860 for the LLaMA Board EXPOSE 7860 -CMD [ "llamafactory-cli", "webui" ] +# Expose port 8000 for the API service +EXPOSE 8000 diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..82c51f63 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include LICENSE requirements.txt diff --git a/Makefile b/Makefile index 3a4a12c9..3f13b215 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: quality style +.PHONY: quality style test check_dirs := scripts src tests @@ -9,3 +9,6 @@ quality: style: ruff check $(check_dirs) --fix ruff format $(check_dirs) + +test: + CUDA_VISIBLE_DEVICES= pytest tests/ diff --git a/README.md b/README.md index 78312e07..cb9a7222 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,10 @@ [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing) +[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing) [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) @@ -25,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89 Choose your path: - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing +- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory - **Local machine**: Please refer to [usage](#getting-started) ## Table of Contents @@ -45,9 +47,9 @@ Choose your path: ## Features - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. -- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO and ORPO. +- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. 
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning. +- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. @@ -69,14 +71,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[24/06/16] We supported **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. + +[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. + [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage. +
<details><summary>Full Changelog</summary> + [24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models; you need to fine-tune them with the `gemma` template for chat completion. [24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage. -
Full Changelog - [24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details. [24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage. @@ -145,38 +151,38 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Supported Models -| Model | Model size | Default module | Template | -| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- | -| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 | -| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | -| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | -| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 | -| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere | -| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek | -| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | query_key_value | falcon | -| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma | -| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 | -| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - | -| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | -| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna | -| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | -| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma | -| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi | -| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | -| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen | -| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - | -| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse | -| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi | -| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl | -| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan | +| Model | Model size | Template | +| -------------------------------------------------------- | -------------------------------- | --------- | +| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma | +| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 | +| 
[InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | +| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| [PaliGemma](https://huggingface.co/google) | 3B | gemma | +| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen | +| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen | +| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen | +| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | > [!NOTE] -> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence. -> -> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models. +> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models. > > Remember to use the **SAME** template in training and inference. @@ -208,6 +214,8 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) +- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb) +- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) @@ -251,6 +259,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction) - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo) +- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) @@ -267,6 +276,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
Preference datasets - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) +- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) @@ -286,21 +296,21 @@ huggingface-cli login | Mandatory | Minimum | Recommend | | ------------ | ------- | --------- | -| python | 3.8 | 3.10 | -| torch | 1.13.1 | 2.2.0 | -| transformers | 4.37.2 | 4.41.0 | -| datasets | 2.14.3 | 2.19.1 | -| accelerate | 0.27.2 | 0.30.1 | -| peft | 0.9.0 | 0.11.1 | -| trl | 0.8.2 | 0.8.6 | +| python | 3.8 | 3.11 | +| torch | 1.13.1 | 2.3.0 | +| transformers | 4.41.2 | 4.41.2 | +| datasets | 2.16.0 | 2.19.2 | +| accelerate | 0.30.1 | 0.30.1 | +| peft | 0.11.1 | 0.11.1 | +| trl | 0.8.6 | 0.9.4 | | Optional | Minimum | Recommend | | ------------ | ------- | --------- | | CUDA | 11.6 | 12.2 | | deepspeed | 0.10.0 | 0.14.0 | | bitsandbytes | 0.39.0 | 0.43.1 | -| vllm | 0.4.0 | 0.4.2 | -| flash-attn | 2.3.0 | 2.5.8 | +| vllm | 0.4.3 | 0.4.3 | +| flash-attn | 2.3.0 | 2.5.9 | ### Hardware Requirement @@ -326,10 +336,10 @@ huggingface-cli login ```bash git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git cd LLaMA-Factory -pip install -e .[torch,metrics] +pip install -e ".[torch,metrics]" ``` -Extra dependencies available: torch, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality +Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality > [!TIP] > Use `pip install --no-deps -e .` to resolve package conflicts. @@ -350,14 +360,28 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec Join [NPU user group](assets/wechat_npu.jpg). -To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. +To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. 
Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands: -| Requirement | Minimum | Recommend | -| ------------ | ------- | --------- | -| CANN | 8.0.RC1 | 8.0.RC1 | -| torch | 2.2.0 | 2.2.0 | -| torch-npu | 2.2.0 | 2.2.0 | -| deepspeed | 0.13.2 | 0.13.2 | +```bash +# replace the url according to your CANN version and devices +# install CANN Toolkit +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run +bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install + +# install CANN Kernels +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run +bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install + +# set env variables +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +| Requirement | Minimum | Recommend | +| ------------ | ------- | ----------- | +| CANN | 8.0.RC1 | 8.0.RC1 | +| torch | 2.1.0 | 2.1.0 | +| torch-npu | 2.1.0 | 2.1.0.post3 | +| deepspeed | 0.13.2 | 0.13.2 | Docker image: @@ -382,9 +406,9 @@ Please refer to [data/README.md](data/README.md) for checking the details about Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively. ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml ``` See [examples/README.md](examples/README.md) for advanced usage (including distributed training). @@ -394,36 +418,38 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio)) -> [!IMPORTANT] -> LLaMA Board GUI only supports training on a single GPU. - -#### Use local environment - ```bash -CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui +llamafactory-cli webui ``` -
+### Build Docker #### Use Docker ```bash -docker build -f ./Dockerfile -t llama-factory:latest . -docker run --gpus=all \ +docker build -f ./Dockerfile \ + --build-arg INSTALL_BNB=false \ + --build-arg INSTALL_VLLM=false \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +docker run -it --gpus=all \ -v ./hf_cache:/root/.cache/huggingface/ \ -v ./data:/app/data \ -v ./output:/app/output \ - -e CUDA_VISIBLE_DEVICES=0 \ -p 7860:7860 \ + -p 8000:8000 \ --shm-size 16G \ - --name llama_factory \ - -d llama-factory:latest + --name llamafactory \ + llamafactory:latest ``` #### Use Docker Compose ```bash -docker compose -f ./docker-compose.yml up -d +docker-compose up -d +docker-compose exec llamafactory bash ```
Details about volume @@ -437,9 +463,12 @@ docker compose -f ./docker-compose.yml up -d ### Deploy with OpenAI-style API and vLLM ```bash -CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml +API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml ``` +> [!TIP] +> Visit https://platform.openai.com/docs/api-reference/chat/create for API document. + ### Download from ModelScope Hub If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope. @@ -448,7 +477,18 @@ If you have trouble with downloading models and datasets from Hugging Face, you export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows ``` -Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`. +Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`. + +### Use W&B Logger + +To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments. + +```yaml +report_to: wandb +run_name: test_run # optional +``` + +Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account. ## Projects using LLaMA Factory @@ -507,7 +547,7 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). -Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / 
[Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation diff --git a/README_zh.md b/README_zh.md index 5acf3dd1..5c005f30 100644 --- a/README_zh.md +++ b/README_zh.md @@ -8,9 +8,10 @@ [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing) +[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board) [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board) -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing) [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535) @@ -25,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd 选择你的打开方式: - **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing +- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory - **本地机器**:请见[如何使用](#如何使用) ## 目录 @@ -45,9 +47,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 项目特色 - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 -- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练和 ORPO 训练。 +- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。 -- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。 +- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - 
**实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 @@ -69,14 +71,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 +[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 + +[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。 + [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 +
<details><summary>展开日志</summary> + [24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。 [24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 -
展开日志 - [24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。 [24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。 @@ -145,40 +151,40 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 模型 -| 模型名 | 模型大小 | 默认模块 | Template | -| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- | -| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 | -| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | -| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - | -| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 | -| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere | -| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek | -| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | query_key_value | falcon | -| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma | -| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 | -| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - | -| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 | -| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna | -| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - | -| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma | -| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi | -| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | -| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen | -| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - | -| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse | -| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi | -| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl | -| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan | +| 模型名 | 模型大小 | Template | +| -------------------------------------------------------- | -------------------------------- | --------- | +| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma | +| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 | +| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 | +| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| 
[LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| [PaliGemma](https://huggingface.co/google) | 3B | gemma | +| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen | +| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen | +| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen | +| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | > [!NOTE] -> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以取得更好的效果。 +> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。 > -> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。 -> -> 请务必在训练和推理时使用**完全一致**的模板。 +> 请务必在训练和推理时采用**完全一致**的模板。 项目所支持模型的完整列表请参阅 [constants.py](src/llamafactory/extras/constants.py)。 @@ -208,6 +214,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile) - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B) +- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb) +- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) @@ -251,6 +259,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia) - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction) - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo) +- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2) - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k) - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de) - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de) @@ -267,6 +276,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
偏好数据集 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) +- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) @@ -286,21 +296,21 @@ huggingface-cli login | 必需项 | 至少 | 推荐 | | ------------ | ------- | --------- | -| python | 3.8 | 3.10 | -| torch | 1.13.1 | 2.2.0 | -| transformers | 4.37.2 | 4.41.0 | -| datasets | 2.14.3 | 2.19.1 | -| accelerate | 0.27.2 | 0.30.1 | -| peft | 0.9.0 | 0.11.1 | -| trl | 0.8.2 | 0.8.6 | +| python | 3.8 | 3.11 | +| torch | 1.13.1 | 2.3.0 | +| transformers | 4.41.2 | 4.41.2 | +| datasets | 2.16.0 | 2.19.2 | +| accelerate | 0.30.1 | 0.30.1 | +| peft | 0.11.1 | 0.11.1 | +| trl | 0.8.6 | 0.9.4 | | 可选项 | 至少 | 推荐 | | ------------ | ------- | --------- | | CUDA | 11.6 | 12.2 | | deepspeed | 0.10.0 | 0.14.0 | | bitsandbytes | 0.39.0 | 0.43.1 | -| vllm | 0.4.0 | 0.4.2 | -| flash-attn | 2.3.0 | 2.5.8 | +| vllm | 0.4.3 | 0.4.3 | +| flash-attn | 2.3.0 | 2.5.9 | ### 硬件依赖 @@ -326,10 +336,10 @@ huggingface-cli login ```bash git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git cd LLaMA-Factory -pip install -e .[torch,metrics] +pip install -e ".[torch,metrics]" ``` -可选的额外依赖项:torch、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality +可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality > [!TIP] > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。 @@ -350,21 +360,35 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl 加入 [NPU 用户群](assets/wechat_npu.jpg)。 -如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。 +在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令: -| 依赖项 | 至少 | 推荐 | -| ------------ | ------- | --------- | -| CANN | 8.0.RC1 | 8.0.RC1 | -| torch | 2.2.0 | 2.2.0 | -| torch-npu | 2.2.0 | 2.2.0 | -| deepspeed | 0.13.2 | 0.13.2 | +```bash +# 请替换 URL 为 CANN 版本和设备型号对应的 URL +# 安装 CANN Toolkit +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run +bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install + +# 安装 CANN Kernels +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run +bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +| 依赖项 | 至少 | 推荐 | +| ------------ | ------- | ----------- | +| CANN | 8.0.RC1 | 8.0.RC1 | +| torch | 2.1.0 | 2.1.0 | +| torch-npu | 2.1.0 | 2.1.0.post3 | +| deepspeed | 0.13.2 | 0.13.2 | Docker 镜像: - 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) - 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) -请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。 
+请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。 如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。 @@ -382,9 +406,9 @@ Docker 镜像: 下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml ``` 高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。 @@ -394,34 +418,38 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s ### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动) -> [!IMPORTANT] -> LLaMA Board 可视化界面目前仅支持单 GPU 训练。 - -#### 使用本地环境 - ```bash -CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui +llamafactory-cli webui ``` +### 构建 Docker + #### 使用 Docker ```bash -docker build -f ./Dockerfile -t llama-factory:latest . -docker run --gpus=all \ +docker build -f ./Dockerfile \ + --build-arg INSTALL_BNB=false \ + --build-arg INSTALL_VLLM=false \ + --build-arg INSTALL_DEEPSPEED=false \ + --build-arg PIP_INDEX=https://pypi.org/simple \ + -t llamafactory:latest . + +docker run -it --gpus=all \ -v ./hf_cache:/root/.cache/huggingface/ \ -v ./data:/app/data \ -v ./output:/app/output \ - -e CUDA_VISIBLE_DEVICES=0 \ -p 7860:7860 \ + -p 8000:8000 \ --shm-size 16G \ - --name llama_factory \ - -d llama-factory:latest + --name llamafactory \ + llamafactory:latest ``` #### 使用 Docker Compose ```bash -docker compose -f ./docker-compose.yml up -d +docker-compose up -d +docker-compose exec llamafactory bash ```
数据卷详情 @@ -435,9 +463,12 @@ docker compose -f ./docker-compose.yml up -d ### 利用 vLLM 部署 OpenAI API ```bash -CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml +API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml ``` +> [!TIP] +> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。 + ### 从魔搭社区下载 如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。 @@ -446,7 +477,18 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/l export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1` ``` -将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。 +将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。 + +### 使用 W&B 面板 + +若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。 + +```yaml +report_to: wandb +run_name: test_run # 可选 +``` + +在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。 ## 使用了 LLaMA Factory 的项目 @@ -505,7 +547,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1` 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / 
[Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 diff --git a/assets/wechat.jpg b/assets/wechat.jpg index d1f1b609..5868f16c 100644 Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ diff --git a/assets/wechat_npu.jpg b/assets/wechat_npu.jpg index 353e7603..cf7019b5 100644 Binary files a/assets/wechat_npu.jpg and b/assets/wechat_npu.jpg differ diff --git a/data/README.md b/data/README.md index 3e96fbeb..5ceae666 100644 --- a/data/README.md +++ b/data/README.md @@ -12,6 +12,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format. "ranking": "whether the dataset is a preference dataset or not. (default: False)", "subset": "the name of the subset. (optional, default: None)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)", + "num_samples": "the number of samples in the dataset used for training. (optional, default: None)", "columns (optional)": { "prompt": "the column name in the dataset containing the prompts. (default: instruction)", "query": "the column name in the dataset containing the queries. (default: input)", diff --git a/data/README_zh.md b/data/README_zh.md index aff6fdb1..1795f352 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -12,6 +12,7 @@ "ranking": "是否为偏好数据集(可选,默认:False)", "subset": "数据集子集的名称(可选,默认:None)", "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)", + "num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)", "columns(可选)": { "prompt": "数据集代表提示词的表头名称(默认:instruction)", "query": "数据集代表请求的表头名称(默认:input)", diff --git a/data/dataset_info.json b/data/dataset_info.json index 7420f3a8..1d226b3a 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -248,6 +248,10 @@ "ruozhiba_gpt4": { "hf_hub_url": "hfl/ruozhiba_gpt4_turbo" }, + "neo_sft": { + "hf_hub_url": "m-a-p/neo_sft_phase2", + "formatting": "sharegpt" + }, "llava_1k_en": { "hf_hub_url": "BUAADreamer/llava-en-zh-2k", "subset": "en", @@ -308,6 +312,20 @@ "assistant_tag": "assistant" } }, + "mllm_pt_demo": { + "hf_hub_url": "BUAADreamer/mllm_pt_demo", + "formatting": "sharegpt", + "columns": { + "messages": "messages", + "images": "images" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "user", + "assistant_tag": "assistant" + } + }, "oasst_de": { "hf_hub_url": "mayflowergmbh/oasst_de" }, @@ -377,6 +395,16 @@ "rejected": "rejected" } }, + "ultrafeedback": { + "hf_hub_url": "llamafactory/ultrafeedback_binarized", + "ms_hub_url": "llamafactory/ultrafeedback_binarized", + "ranking": true, + "columns": { + "prompt": "instruction", + "chosen": "chosen", + "rejected": "rejected" + } + }, "orca_pairs": { "hf_hub_url": "Intel/orca_dpo_pairs", "ranking": true, @@ -434,6 +462,15 @@ "assistant_tag": "assistant" } }, + "ultrafeedback_kto": { + "hf_hub_url": "argilla/ultrafeedback-binarized-preferences-cleaned-kto", + "ms_hub_url": "AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto", + "columns": { + "prompt": "prompt", + "response": "completion", + "kto_tag": 
"label" + } + }, "wiki_demo": { "file_name": "wiki_demo.txt", "columns": { @@ -487,6 +524,18 @@ "prompt": "text" } }, + "fileweb": { + "hf_hub_url": "HuggingFaceFW/fineweb", + "columns": { + "prompt": "text" + } + }, + "fileweb_edu": { + "hf_hub_url": "HuggingFaceFW/fineweb-edu", + "columns": { + "prompt": "text" + } + }, "the_stack": { "hf_hub_url": "bigcode/the-stack", "ms_hub_url": "AI-ModelScope/the-stack", diff --git a/docker-compose.yml b/docker-compose.yml index 333dc51e..c5dc34e9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,20 +1,25 @@ -version: '3.8' - services: - llama-factory: + llamafactory: build: dockerfile: Dockerfile context: . - container_name: llama_factory + args: + INSTALL_BNB: false + INSTALL_VLLM: false + INSTALL_DEEPSPEED: false + PIP_INDEX: https://pypi.org/simple + container_name: llamafactory volumes: - ./hf_cache:/root/.cache/huggingface/ - ./data:/app/data - ./output:/app/output - environment: - - CUDA_VISIBLE_DEVICES=0 ports: - "7860:7860" + - "8000:8000" ipc: host + tty: true + stdin_open: true + command: bash deploy: resources: reservations: diff --git a/evaluation/ceval/ceval.py b/evaluation/ceval/ceval.py index 4111d6b4..48442d50 100644 --- a/evaluation/ceval/ceval.py +++ b/evaluation/ceval/ceval.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import datasets diff --git a/evaluation/cmmlu/cmmlu.py b/evaluation/cmmlu/cmmlu.py index 37efb328..5ff548a4 100644 --- a/evaluation/cmmlu/cmmlu.py +++ b/evaluation/cmmlu/cmmlu.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import datasets diff --git a/evaluation/mmlu/mmlu.py b/evaluation/mmlu/mmlu.py index f3218c38..1065fb31 100644 --- a/evaluation/mmlu/mmlu.py +++ b/evaluation/mmlu/mmlu.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import datasets @@ -154,7 +155,7 @@ class MMLU(datasets.GeneratorBasedBuilder): ] def _generate_examples(self, filepath): - df = pd.read_csv(filepath) + df = pd.read_csv(filepath, header=None) df.columns = ["question", "A", "B", "C", "D", "answer"] for i, instance in enumerate(df.to_dict(orient="records")): diff --git a/examples/README.md b/examples/README.md index 9c6d5fb0..007a81ab 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory. 
## Table of Contents
-- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu)
-- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu)
-- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus)
-- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
-- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
+- [LoRA Fine-Tuning](#lora-fine-tuning)
+- [QLoRA Fine-Tuning](#qlora-fine-tuning)
+- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
- [Extras](#extras)
+Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
+
## Examples
-### LoRA Fine-Tuning on A Single GPU
+### LoRA Fine-Tuning
#### (Continuous) Pre-Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
```
#### Supervised Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### Multimodal Supervised Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
```
#### Reward Modeling
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
```
#### PPO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
```
#### DPO/ORPO/SimPO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO Training
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
```
#### Preprocess Dataset
@@ -64,93 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
+llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
```
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
+llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```
-### QLoRA Fine-Tuning on a Single GPU
-
-#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
+#### Supervised Fine-Tuning on Multiple Nodes
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
-```
-
-#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
-```
-
-#### Supervised Fine-Tuning with 4-bit AWQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
-```
-
-#### Supervised Fine-Tuning with 2-bit AQLM Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
-```
-
-### LoRA Fine-Tuning on Multiple GPUs
-
-#### Supervised Fine-Tuning with Accelerate on Single Node
-
-```bash
-bash examples/lora_multi_gpu/single_node.sh
-```
-
-#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
-
-```bash
-bash examples/lora_multi_gpu/multi_node.sh
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash
-bash examples/lora_multi_gpu/ds_zero3.sh
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
-### LoRA Fine-Tuning on Multiple NPUs
+### QLoRA Fine-Tuning
-#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
+#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
```bash
-bash examples/lora_multi_npu/ds_zero0.sh
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
```
-### Full-Parameter Fine-Tuning on Multiple GPUs
-
-#### Supervised Fine-Tuning with Accelerate on Single Node
+#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
```bash
-bash examples/full_multi_gpu/single_node.sh
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
```
-#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
+#### Supervised Fine-Tuning with 4-bit AWQ Quantization
```bash
-bash examples/full_multi_gpu/multi_node.sh
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
+```
+
+#### Supervised Fine-Tuning with 2-bit AQLM Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
+```
+
+### Full-Parameter Fine-Tuning
+
+#### Supervised Fine-Tuning on Single Node
+
+```bash
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+```
+
+#### Supervised Fine-Tuning on Multiple Nodes
+
+```bash
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Batch Predicting and Computing BLEU and ROUGE Scores
```bash
-bash examples/full_multi_gpu/predict.sh
+llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```
### Merging LoRA Adapters and Quantization
@@ -160,35 +146,33 @@ bash examples/full_multi_gpu/predict.sh
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
```
#### Quantizing Model using AutoGPTQ
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
```
### Inferring LoRA Fine-Tuned Models
-Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
-
#### Use CLI
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```
#### Use Web UI
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
```
#### Launch OpenAI-style API
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/llama3_lora_sft.yaml
```
### Extras
@@ -196,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
#### Full-Parameter Fine-Tuning using GaLore
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
```
#### Full-Parameter Fine-Tuning using BAdam
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
```
#### LoRA+ Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+```
+
+#### PiSSA Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```
#### Mixture-of-Depths Fine-Tuning
```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
```
#### LLaMA-Pro Fine-Tuning
```bash
bash examples/extras/llama_pro/expand.sh
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
+llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```
#### FSDP+QLoRA Fine-Tuning
```bash
-bash examples/extras/fsdp_qlora/single_node.sh
+bash examples/extras/fsdp_qlora/train.sh
```
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 0ff33398..b9d90f25 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -4,59 +4,59 @@
## 目录
-- [单 GPU LoRA 微调](#单-gpu-lora-微调)
-- [单 GPU QLoRA 微调](#单-gpu-qlora-微调)
-- [多 GPU LoRA 微调](#多-gpu-lora-微调)
-- [多 NPU LoRA 微调](#多-npu-lora-微调)
-- [多 GPU 全参数微调](#多-gpu-全参数微调)
+- [LoRA 微调](#lora-微调)
+- [QLoRA 微调](#qlora-微调)
+- [全参数微调](#全参数微调)
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
- [推理 LoRA 模型](#推理-lora-模型)
- [杂项](#杂项)
+使用
`CUDA_VISIBLE_DEVICES`(GPU)或 `ASCEND_RT_VISIBLE_DEVICES`(NPU)选择计算设备。 + ## 示例 -### 单 GPU LoRA 微调 +### LoRA 微调 #### (增量)预训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml +llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml ``` #### 指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml +llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ``` #### 多模态指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml +llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml ``` #### 奖励模型训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml +llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml ``` #### PPO 训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml +llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml ``` #### DPO/ORPO/SimPO 训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml +llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml ``` #### KTO 训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml +llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml ``` #### 预处理数据集 @@ -64,93 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo 对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml +llamafactory-cli train examples/train_lora/llama3_preprocess.yaml ``` #### 在 MMLU/CMMLU/C-Eval 上评估 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml +llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml ``` #### 批量预测并计算 BLEU 和 ROUGE 分数 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml +llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml ``` -### 单 GPU QLoRA 微调 - -#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐) +#### 多机指令监督微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml -``` - -#### 基于 4/8 比特 GPTQ 量化进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml -``` - -#### 基于 4 比特 AWQ 量化进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml -``` - -#### 基于 2 比特 AQLM 量化进行指令监督微调 - -```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml -``` - -### 多 GPU LoRA 微调 - -#### 使用 Accelerate 进行单节点训练 - -```bash -bash examples/lora_multi_gpu/single_node.sh -``` - -#### 使用 Accelerate 进行多节点训练 - -```bash -bash examples/lora_multi_gpu/multi_node.sh +FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ``` #### 使用 DeepSpeed ZeRO-3 平均分配显存 ```bash -bash examples/lora_multi_gpu/ds_zero3.sh +FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml ``` -### 多 NPU LoRA 微调 +### QLoRA 微调 -#### 使用 DeepSpeed ZeRO-0 训练 +#### 基于 4/8 比特 Bitsandbytes 
量化进行指令监督微调(推荐) ```bash -bash examples/lora_multi_npu/ds_zero0.sh +llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml ``` -### 多 GPU 全参数微调 - -#### 使用 DeepSpeed 进行单节点训练 +#### 基于 4/8 比特 GPTQ 量化进行指令监督微调 ```bash -bash examples/full_multi_gpu/single_node.sh +llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml ``` -#### 使用 DeepSpeed 进行多节点训练 +#### 基于 4 比特 AWQ 量化进行指令监督微调 ```bash -bash examples/full_multi_gpu/multi_node.sh +llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml +``` + +#### 基于 2 比特 AQLM 量化进行指令监督微调 + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml +``` + +### 全参数微调 + +#### 在单机上进行指令监督微调 + +```bash +FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml +``` + +#### 在多机上进行指令监督微调 + +```bash +FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml +FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml ``` #### 批量预测并计算 BLEU 和 ROUGE 分数 ```bash -bash examples/full_multi_gpu/predict.sh +llamafactory-cli train examples/train_full/llama3_full_predict.yaml ``` ### 合并 LoRA 适配器与模型量化 @@ -160,35 +146,33 @@ bash examples/full_multi_gpu/predict.sh 注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml +llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml ``` #### 使用 AutoGPTQ 量化模型 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml +llamafactory-cli export examples/merge_lora/llama3_gptq.yaml ``` ### 推理 LoRA 模型 -使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。 - #### 使用命令行接口 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml +llamafactory-cli chat examples/inference/llama3_lora_sft.yaml ``` #### 使用浏览器界面 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml +llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml ``` #### 启动 OpenAI 风格 API ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml +llamafactory-cli api examples/inference/llama3_lora_sft.yaml ``` ### 杂项 @@ -196,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y #### 使用 GaLore 进行全参数训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml +llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml ``` #### 使用 BAdam 进行全参数训练 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml +llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml ``` #### LoRA+ 微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml +llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml +``` + +#### PiSSA 微调 + +```bash +llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml ``` #### 深度混合微调 ```bash -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml +llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml ``` #### LLaMA-Pro 微调 ```bash bash examples/extras/llama_pro/expand.sh -CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml +llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml ``` #### FSDP+QLoRA 
微调 ```bash -bash examples/extras/fsdp_qlora/single_node.sh +bash examples/extras/fsdp_qlora/train.sh ``` diff --git a/examples/accelerate/fsdp_config.yaml b/examples/accelerate/fsdp_config.yaml index 60025597..cd65e074 100644 --- a/examples/accelerate/fsdp_config.yaml +++ b/examples/accelerate/fsdp_config.yaml @@ -5,16 +5,16 @@ downcast_bf16: 'no' fsdp_config: fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_backward_prefetch: BACKWARD_PRE - fsdp_cpu_ram_efficient_loading: true fsdp_forward_prefetch: false - fsdp_offload_params: true + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: true # offload may affect training speed fsdp_sharding_strategy: FULL_SHARD fsdp_state_dict_type: FULL_STATE_DICT fsdp_sync_module_states: true - fsdp_use_orig_params: false + fsdp_use_orig_params: true machine_rank: 0 main_training_function: main -mixed_precision: fp16 +mixed_precision: fp16 # or bf16 num_machines: 1 # the number of nodes num_processes: 2 # the number of GPUs in all nodes rdzv_backend: static diff --git a/examples/accelerate/master_config.yaml b/examples/accelerate/master_config.yaml deleted file mode 100644 index a1018313..00000000 --- a/examples/accelerate/master_config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: MULTI_GPU -downcast_bf16: 'no' -gpu_ids: all -machine_rank: 0 -main_process_ip: 192.168.0.1 -main_process_port: 29555 -main_training_function: main -mixed_precision: fp16 -num_machines: 2 # the number of nodes -num_processes: 8 # the number of GPUs in all nodes -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/accelerate/single_config.yaml b/examples/accelerate/single_config.yaml deleted file mode 100644 index 97f8c633..00000000 --- a/examples/accelerate/single_config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: MULTI_GPU -downcast_bf16: 'no' -gpu_ids: all -machine_rank: 0 -main_training_function: main -mixed_precision: fp16 -num_machines: 1 # the number of nodes -num_processes: 4 # the number of GPUs in all nodes -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/accelerate/slave_config.yaml b/examples/accelerate/slave_config.yaml deleted file mode 100644 index e610fd0e..00000000 --- a/examples/accelerate/slave_config.yaml +++ /dev/null @@ -1,18 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: MULTI_GPU -downcast_bf16: 'no' -gpu_ids: all -machine_rank: 1 -main_process_ip: 192.168.0.1 -main_process_port: 29555 -main_training_function: main -mixed_precision: fp16 -num_machines: 2 # the number of nodes -num_processes: 8 # the number of GPUs in all nodes -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/extras/badam/llama3_lora_sft.yaml b/examples/extras/badam/llama3_lora_sft.yaml index 4a482749..a78de2fa 100644 --- a/examples/extras/badam/llama3_lora_sft.yaml +++ b/examples/extras/badam/llama3_lora_sft.yaml @@ -28,14 +28,14 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 pure_bf16: true ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: 
steps eval_steps: 500 diff --git a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml index e9c04fa9..cc773991 100644 --- a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml +++ b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml @@ -6,10 +6,7 @@ quantization_bit: 4 stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj - -### ddp -ddp_timeout: 180000000 +lora_target: all ### dataset dataset: identity,alpaca_en_demo @@ -29,14 +26,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/extras/fsdp_qlora/single_node.sh b/examples/extras/fsdp_qlora/train.sh similarity index 68% rename from examples/extras/fsdp_qlora/single_node.sh rename to examples/extras/fsdp_qlora/train.sh index 54ec2bd2..fac8cdee 100644 --- a/examples/extras/fsdp_qlora/single_node.sh +++ b/examples/extras/fsdp_qlora/train.sh @@ -1,10 +1,6 @@ #!/bin/bash # DO NOT use GPTQ/AWQ model in FSDP+QLoRA -pip install "transformers>=4.39.1" -pip install "accelerate>=0.28.0" -pip install "bitsandbytes>=0.43.0" - CUDA_VISIBLE_DEVICES=0,1 accelerate launch \ --config_file examples/accelerate/fsdp_config.yaml \ src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml diff --git a/examples/extras/galore/llama3_full_sft.yaml b/examples/extras/galore/llama3_full_sft.yaml index 87381fcc..605545de 100644 --- a/examples/extras/galore/llama3_full_sft.yaml +++ b/examples/extras/galore/llama3_full_sft.yaml @@ -29,14 +29,14 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 1 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 pure_bf16: true ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/extras/llama_pro/llama3_freeze_sft.yaml b/examples/extras/llama_pro/llama3_freeze_sft.yaml index 8ace8db8..f92d6945 100644 --- a/examples/extras/llama_pro/llama3_freeze_sft.yaml +++ b/examples/extras/llama_pro/llama3_freeze_sft.yaml @@ -27,14 +27,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/extras/loraplus/llama3_lora_sft.yaml b/examples/extras/loraplus/llama3_lora_sft.yaml index 26c2b1d2..57383ae0 100644 --- a/examples/extras/loraplus/llama3_lora_sft.yaml +++ b/examples/extras/loraplus/llama3_lora_sft.yaml @@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all loraplus_lr_ratio: 16.0 ### dataset @@ -26,14 +26,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true 
+ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/extras/mod/llama3_full_sft.yaml b/examples/extras/mod/llama3_full_sft.yaml index 6b724ed0..085febfc 100644 --- a/examples/extras/mod/llama3_full_sft.yaml +++ b/examples/extras/mod/llama3_full_sft.yaml @@ -26,14 +26,15 @@ overwrite_output_dir: true per_device_train_batch_size: 1 gradient_accumulation_steps: 8 optim: paged_adamw_8bit -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 pure_bf16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_multi_gpu/llama3_lora_sft.yaml b/examples/extras/pissa/llama3_lora_sft.yaml similarity index 78% rename from examples/lora_multi_gpu/llama3_lora_sft.yaml rename to examples/extras/pissa/llama3_lora_sft.yaml index 6389f21b..fd4b9f1d 100644 --- a/examples/lora_multi_gpu/llama3_lora_sft.yaml +++ b/examples/extras/pissa/llama3_lora_sft.yaml @@ -5,10 +5,10 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj - -### ddp -ddp_timeout: 180000000 +lora_target: all +pissa_init: true +pissa_iter: 4 +pissa_convert: true ### dataset dataset: identity,alpaca_en_demo @@ -27,15 +27,16 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 -gradient_accumulation_steps: 2 -learning_rate: 0.0001 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/full_multi_gpu/multi_node.sh b/examples/full_multi_gpu/multi_node.sh deleted file mode 100644 index 34c038d4..00000000 --- a/examples/full_multi_gpu/multi_node.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -NPROC_PER_NODE=4 -NNODES=2 -RANK=0 -MASTER_ADDR=192.168.0.1 -MASTER_PORT=29500 - -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ - --nproc_per_node $NPROC_PER_NODE \ - --nnodes $NNODES \ - --node_rank $RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT \ - src/train.py examples/full_multi_gpu/llama3_full_sft.yaml diff --git a/examples/full_multi_gpu/predict.sh b/examples/full_multi_gpu/predict.sh deleted file mode 100644 index 2445f444..00000000 --- a/examples/full_multi_gpu/predict.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ - --config_file examples/accelerate/single_config.yaml \ - src/train.py examples/full_multi_gpu/llama3_full_predict.yaml diff --git a/examples/full_multi_gpu/single_node.sh b/examples/full_multi_gpu/single_node.sh deleted file mode 100644 index ac29c097..00000000 --- a/examples/full_multi_gpu/single_node.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -NPROC_PER_NODE=4 -NNODES=1 -RANK=0 -MASTER_ADDR=127.0.0.1 -MASTER_PORT=29500 - -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ - --nproc_per_node $NPROC_PER_NODE \ - --nnodes $NNODES \ - --node_rank $RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT \ - src/train.py examples/full_multi_gpu/llama3_full_sft.yaml diff --git a/examples/lora_multi_gpu/ds_zero3.sh b/examples/lora_multi_gpu/ds_zero3.sh deleted file mode 100644 index 90ea00dd..00000000 --- 
a/examples/lora_multi_gpu/ds_zero3.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -NPROC_PER_NODE=4 -NNODES=1 -RANK=0 -MASTER_ADDR=127.0.0.1 -MASTER_PORT=29500 - -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ - --nproc_per_node $NPROC_PER_NODE \ - --nnodes $NNODES \ - --node_rank $RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT \ - src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml diff --git a/examples/lora_multi_gpu/multi_node.sh b/examples/lora_multi_gpu/multi_node.sh deleted file mode 100644 index 401fac5f..00000000 --- a/examples/lora_multi_gpu/multi_node.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -# also launch it on slave machine using slave_config.yaml - -CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ - --config_file examples/accelerate/master_config.yaml \ - src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml diff --git a/examples/lora_multi_gpu/single_node.sh b/examples/lora_multi_gpu/single_node.sh deleted file mode 100644 index 885a0e8c..00000000 --- a/examples/lora_multi_gpu/single_node.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \ - --config_file examples/accelerate/single_config.yaml \ - src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml diff --git a/examples/lora_multi_npu/ds_zero0.sh b/examples/lora_multi_npu/ds_zero0.sh deleted file mode 100644 index 4ffaa1b0..00000000 --- a/examples/lora_multi_npu/ds_zero0.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -NPROC_PER_NODE=4 -NNODES=1 -RANK=0 -MASTER_ADDR=127.0.0.1 -MASTER_PORT=29500 - -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \ - --nproc_per_node $NPROC_PER_NODE \ - --nnodes $NNODES \ - --node_rank $RANK \ - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT \ - src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml diff --git a/examples/full_multi_gpu/llama3_full_predict.yaml b/examples/train_full/llama3_full_predict.yaml similarity index 100% rename from examples/full_multi_gpu/llama3_full_predict.yaml rename to examples/train_full/llama3_full_predict.yaml diff --git a/examples/full_multi_gpu/llama3_full_sft.yaml b/examples/train_full/llama3_full_sft_ds3.yaml similarity index 89% rename from examples/full_multi_gpu/llama3_full_sft.yaml rename to examples/train_full/llama3_full_sft_ds3.yaml index a96f1b8e..40afd2ee 100644 --- a/examples/full_multi_gpu/llama3_full_sft.yaml +++ b/examples/train_full/llama3_full_sft_ds3.yaml @@ -5,9 +5,6 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: full - -### ddp -ddp_timeout: 180000000 deepspeed: examples/deepspeed/ds_z3_config.json ### dataset @@ -28,14 +25,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 2 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_single_gpu/llama3_lora_dpo.yaml b/examples/train_lora/llama3_lora_dpo.yaml similarity index 86% rename from examples/lora_single_gpu/llama3_lora_dpo.yaml rename to examples/train_lora/llama3_lora_dpo.yaml index f68244b7..db25fb51 100644 --- a/examples/lora_single_gpu/llama3_lora_dpo.yaml +++ b/examples/train_lora/llama3_lora_dpo.yaml @@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: dpo do_train: true finetuning_type: lora 
-lora_target: q_proj,v_proj +lora_target: all pref_beta: 0.1 pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo] @@ -27,14 +27,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.000005 +learning_rate: 5.0e-6 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_single_gpu/llama3_lora_eval.yaml b/examples/train_lora/llama3_lora_eval.yaml similarity index 100% rename from examples/lora_single_gpu/llama3_lora_eval.yaml rename to examples/train_lora/llama3_lora_eval.yaml diff --git a/examples/lora_single_gpu/llama3_lora_kto.yaml b/examples/train_lora/llama3_lora_kto.yaml similarity index 83% rename from examples/lora_single_gpu/llama3_lora_kto.yaml rename to examples/train_lora/llama3_lora_kto.yaml index 4405aaec..f730c82e 100644 --- a/examples/lora_single_gpu/llama3_lora_kto.yaml +++ b/examples/train_lora/llama3_lora_kto.yaml @@ -5,7 +5,8 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: kto do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all +pref_beta: 0.1 ### dataset dataset: kto_en_demo @@ -25,14 +26,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.000005 +learning_rate: 5.0e-6 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_single_gpu/llama3_lora_ppo.yaml b/examples/train_lora/llama3_lora_ppo.yaml similarity index 88% rename from examples/lora_single_gpu/llama3_lora_ppo.yaml rename to examples/train_lora/llama3_lora_ppo.yaml index 88ce24f3..e574014e 100644 --- a/examples/lora_single_gpu/llama3_lora_ppo.yaml +++ b/examples/train_lora/llama3_lora_ppo.yaml @@ -6,7 +6,7 @@ reward_model: saves/llama3-8b/lora/reward stage: ppo do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: identity,alpaca_en_demo @@ -26,11 +26,12 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.00001 +learning_rate: 1.0e-5 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### generate max_new_tokens: 512 diff --git a/examples/lora_single_gpu/llama3_lora_predict.yaml b/examples/train_lora/llama3_lora_predict.yaml similarity index 95% rename from examples/lora_single_gpu/llama3_lora_predict.yaml rename to examples/train_lora/llama3_lora_predict.yaml index a127d248..148c8635 100644 --- a/examples/lora_single_gpu/llama3_lora_predict.yaml +++ b/examples/train_lora/llama3_lora_predict.yaml @@ -22,3 +22,4 @@ overwrite_output_dir: true ### eval per_device_eval_batch_size: 1 predict_with_generate: true +ddp_timeout: 180000000 diff --git a/examples/lora_single_gpu/llama3_lora_pretrain.yaml b/examples/train_lora/llama3_lora_pretrain.yaml similarity index 84% rename from examples/lora_single_gpu/llama3_lora_pretrain.yaml rename to examples/train_lora/llama3_lora_pretrain.yaml index acb18ebf..839b3e51 100644 --- a/examples/lora_single_gpu/llama3_lora_pretrain.yaml +++ b/examples/train_lora/llama3_lora_pretrain.yaml @@ -5,7 +5,7 @@ 
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: pt do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: c4_demo @@ -24,14 +24,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_single_gpu/llama3_lora_reward.yaml b/examples/train_lora/llama3_lora_reward.yaml similarity index 85% rename from examples/lora_single_gpu/llama3_lora_reward.yaml rename to examples/train_lora/llama3_lora_reward.yaml index 6bf2ca02..79559d19 100644 --- a/examples/lora_single_gpu/llama3_lora_reward.yaml +++ b/examples/train_lora/llama3_lora_reward.yaml @@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: rm do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: dpo_en_demo @@ -25,14 +25,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.00001 +learning_rate: 1.0e-5 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_single_gpu/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml similarity index 85% rename from examples/lora_single_gpu/llama3_lora_sft.yaml rename to examples/train_lora/llama3_lora_sft.yaml index 5492bc34..fe30c575 100644 --- a/examples/lora_single_gpu/llama3_lora_sft.yaml +++ b/examples/train_lora/llama3_lora_sft.yaml @@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: identity,alpaca_en_demo @@ -25,14 +25,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds0.yaml similarity index 86% rename from examples/lora_multi_npu/llama3_lora_sft_ds.yaml rename to examples/train_lora/llama3_lora_sft_ds0.yaml index 65ab6347..08b638e6 100644 --- a/examples/lora_multi_npu/llama3_lora_sft_ds.yaml +++ b/examples/train_lora/llama3_lora_sft_ds0.yaml @@ -5,10 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj - -### ddp -ddp_timeout: 180000000 +lora_target: all deepspeed: examples/deepspeed/ds_z0_config.json ### dataset @@ -29,14 +26,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 2 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: 
steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml b/examples/train_lora/llama3_lora_sft_ds3.yaml similarity index 86% rename from examples/lora_multi_gpu/llama3_lora_sft_ds.yaml rename to examples/train_lora/llama3_lora_sft_ds3.yaml index 6011896a..b7266d61 100644 --- a/examples/lora_multi_gpu/llama3_lora_sft_ds.yaml +++ b/examples/train_lora/llama3_lora_sft_ds3.yaml @@ -5,10 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj - -### ddp -ddp_timeout: 180000000 +lora_target: all deepspeed: examples/deepspeed/ds_z3_config.json ### dataset @@ -29,14 +26,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 2 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/lora_single_gpu/llama3_preprocess.yaml b/examples/train_lora/llama3_preprocess.yaml similarity index 93% rename from examples/lora_single_gpu/llama3_preprocess.yaml rename to examples/train_lora/llama3_preprocess.yaml index 86dad37b..34bb9efc 100644 --- a/examples/lora_single_gpu/llama3_preprocess.yaml +++ b/examples/train_lora/llama3_preprocess.yaml @@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: identity,alpaca_en_demo diff --git a/examples/lora_single_gpu/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml similarity index 85% rename from examples/lora_single_gpu/llava1_5_lora_sft.yaml rename to examples/train_lora/llava1_5_lora_sft.yaml index 8e4226da..55ac31fa 100644 --- a/examples/lora_single_gpu/llava1_5_lora_sft.yaml +++ b/examples/train_lora/llava1_5_lora_sft.yaml @@ -6,7 +6,7 @@ visual_inputs: true stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: mllm_demo @@ -26,14 +26,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml b/examples/train_qlora/llama3_lora_sft_aqlm.yaml similarity index 85% rename from examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml rename to examples/train_qlora/llama3_lora_sft_aqlm.yaml index d2658051..7b6767d5 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml +++ b/examples/train_qlora/llama3_lora_sft_aqlm.yaml @@ -5,7 +5,7 @@ model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16 stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: identity,alpaca_en_demo @@ -25,14 +25,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 
per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml b/examples/train_qlora/llama3_lora_sft_awq.yaml similarity index 85% rename from examples/qlora_single_gpu/llama3_lora_sft_awq.yaml rename to examples/train_qlora/llama3_lora_sft_awq.yaml index ba6d8ea5..a2a26e4b 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_awq.yaml +++ b/examples/train_qlora/llama3_lora_sft_awq.yaml @@ -5,7 +5,7 @@ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: identity,alpaca_en_demo @@ -25,14 +25,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml b/examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml similarity index 86% rename from examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml rename to examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml index a3db35ff..cc773991 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml +++ b/examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml @@ -6,7 +6,7 @@ quantization_bit: 4 stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: identity,alpaca_en_demo @@ -26,14 +26,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml b/examples/train_qlora/llama3_lora_sft_gptq.yaml similarity index 85% rename from examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml rename to examples/train_qlora/llama3_lora_sft_gptq.yaml index cc9a454e..ad3d854c 100644 --- a/examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml +++ b/examples/train_qlora/llama3_lora_sft_gptq.yaml @@ -5,7 +5,7 @@ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ stage: sft do_train: true finetuning_type: lora -lora_target: q_proj,v_proj +lora_target: all ### dataset dataset: identity,alpaca_en_demo @@ -25,14 +25,15 @@ overwrite_output_dir: true ### train per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -learning_rate: 0.0001 +learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine -warmup_steps: 0.1 +warmup_ratio: 0.1 fp16: true +ddp_timeout: 180000000 ### eval val_size: 0.1 per_device_eval_batch_size: 1 -evaluation_strategy: steps +eval_strategy: steps eval_steps: 500 diff --git a/requirements.txt b/requirements.txt index f4a942e6..9e00555e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,13 @@ -transformers>=4.37.2 -datasets>=2.14.3 -accelerate>=0.27.2 -peft>=0.10.0 -trl>=0.8.1 +transformers>=4.41.2 +datasets>=2.16.0 +accelerate>=0.30.1 +peft>=0.11.1 +trl>=0.8.6 gradio>=4.0.0 scipy einops sentencepiece +tiktoken protobuf uvicorn pydantic diff 
--git a/scripts/cal_flops.py b/scripts/cal_flops.py index ac87e0ab..32526d89 100644 --- a/scripts/cal_flops.py +++ b/scripts/cal_flops.py @@ -1,7 +1,20 @@ # coding=utf-8 -# Calculates the flops of pre-trained models. -# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 -# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/ +# Copyright 2024 Microsoft Corporation and the LlamaFactory team. +# +# This code is inspired by the Microsoft's DeepSpeed library. +# https://www.deepspeed.ai/tutorials/flops-profiler/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import fire import torch @@ -17,6 +30,10 @@ def calculate_flops( seq_length: int = 256, flash_attn: str = "auto", ): + r""" + Calculates the flops of pre-trained models. + Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 + """ with get_accelerator().device(0): chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device) diff --git a/scripts/cal_lr.py b/scripts/cal_lr.py index bfa32cc9..ad6992cb 100644 --- a/scripts/cal_lr.py +++ b/scripts/cal_lr.py @@ -1,7 +1,20 @@ # coding=utf-8 -# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. -# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 -# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py +# Copyright 2024 imoneoi and the LlamaFactory team. +# +# This code is inspired by the imoneoi's OpenChat library. +# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math from typing import Literal @@ -32,6 +45,10 @@ def calculate_lr( cutoff_len: int = 1024, # i.e. maximum input length during training is_mistral: bool = False, # mistral model uses a smaller learning rate, ): + r""" + Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. 
+ Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 + """ model_args, data_args, training_args, _, _ = get_train_args( dict( stage=stage, diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py index 387b756c..fb503629 100644 --- a/scripts/cal_ppl.py +++ b/scripts/cal_ppl.py @@ -1,6 +1,17 @@ # coding=utf-8 -# Calculates the ppl on the dataset of the pre-trained models. -# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json from dataclasses import dataclass @@ -56,6 +67,10 @@ def cal_ppl( max_samples: Optional[int] = None, train_on_prompt: bool = False, ): + r""" + Calculates the ppl on the dataset of the pre-trained models. + Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json + """ model_args, data_args, training_args, finetuning_args, _ = get_train_args( dict( stage=stage, diff --git a/scripts/length_cdf.py b/scripts/length_cdf.py index 7739dcf0..4cdf01e6 100644 --- a/scripts/length_cdf.py +++ b/scripts/length_cdf.py @@ -1,6 +1,17 @@ # coding=utf-8 -# Calculates the distribution of the input lengths in the dataset. -# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections import defaultdict @@ -19,6 +30,10 @@ def length_cdf( template: str = "default", interval: int = 1000, ): + r""" + Calculates the distribution of the input lengths in the dataset. + Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default + """ model_args, data_args, training_args, _, _ = get_train_args( dict( stage="sft", diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py index 997b3496..17bf6fc2 100644 --- a/scripts/llama_pro.py +++ b/scripts/llama_pro.py @@ -1,7 +1,20 @@ # coding=utf-8 -# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models. -# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 -# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py +# Copyright 2024 Tencent Inc. and the LlamaFactory team. +# +# This code is inspired by the Tencent's LLaMA-Pro library. 
+# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os @@ -37,6 +50,10 @@ def block_expansion( shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False, ): + r""" + Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models. + Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 + """ config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) num_layers = getattr(config, "num_hidden_layers") setattr(config, "num_hidden_layers", num_layers + num_expand) @@ -103,11 +120,11 @@ def block_expansion( json.dump(index, f, indent=2, sort_keys=True) print("Model weights saved in {}".format(output_dir)) - print("Fine-tune this model with:") - print(" --model_name_or_path {} \\".format(output_dir)) - print(" --finetuning_type freeze \\") - print(" --freeze_trainable_layers {} \\".format(num_expand)) - print(" --use_llama_pro") + print("- Fine-tune this model with:") + print("model_name_or_path: {}".format(output_dir)) + print("finetuning_type: freeze") + print("freeze_trainable_layers: {}".format(num_expand)) + print("use_llama_pro: true") if __name__ == "__main__": diff --git a/scripts/llamafy_baichuan2.py b/scripts/llamafy_baichuan2.py index 1ae58879..19284f5f 100644 --- a/scripts/llamafy_baichuan2.py +++ b/scripts/llamafy_baichuan2.py @@ -1,8 +1,17 @@ # coding=utf-8 -# Converts the Baichuan2-7B model in the same format as LLaMA2-7B. -# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output -# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py -# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os @@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str): def llamafy_baichuan2( input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False ): + r""" + Converts the Baichuan2-7B model in the same format as LLaMA2-7B. 
+ Usage: python llamafy_baichuan2.py --input_dir input --output_dir output + Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied + """ try: os.makedirs(output_dir, exist_ok=False) except Exception as e: diff --git a/scripts/llamafy_qwen.py b/scripts/llamafy_qwen.py index 69cf3e8e..e5b59483 100644 --- a/scripts/llamafy_qwen.py +++ b/scripts/llamafy_qwen.py @@ -1,7 +1,17 @@ # coding=utf-8 -# Converts the Qwen models in the same format as LLaMA2. -# Usage: python llamafy_qwen.py --input_dir input --output_dir output -# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os @@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str): def llamafy_qwen( input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False ): + r""" + Converts the Qwen models in the same format as LLaMA2. + Usage: python llamafy_qwen.py --input_dir input --output_dir output + Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied + """ try: os.makedirs(output_dir, exist_ok=False) except Exception as e: diff --git a/scripts/loftq_init.py b/scripts/loftq_init.py index 7f244316..b9506fa3 100644 --- a/scripts/loftq_init.py +++ b/scripts/loftq_init.py @@ -1,14 +1,25 @@ # coding=utf-8 -# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) -# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir -# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is based on the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import fire -import torch -import torch.nn as nn from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer @@ -17,38 +28,21 @@ if TYPE_CHECKING: from transformers import PreTrainedModel -class Shell(nn.Module): - def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - super().__init__() - self.weight = nn.Parameter(weight, requires_grad=False) - if bias is not None: - self.bias = nn.Parameter(bias, requires_grad=False) - - -def unwrap_model(model: nn.Module, pattern=".base_layer") -> None: - for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}: - parent_name = ".".join(name.split(".")[:-1]) - child_name = name.split(".")[-1] - parent_module = model.get_submodule(parent_name) - child_module = getattr(parent_module, child_name) - base_layer = getattr(child_module, "base_layer") - weight = getattr(base_layer, "weight", None) - bias = getattr(base_layer, "bias", None) - setattr(parent_module, child_name, Shell(weight, bias)) - - print("Model unwrapped.") - - def quantize_loftq( model_name_or_path: str, - save_dir: str, - loftq_bits: Optional[int] = 4, - loftq_iter: Optional[int] = 1, - lora_alpha: Optional[int] = None, - lora_rank: Optional[int] = 16, - lora_target: Optional[str] = "q_proj,v_proj", - save_safetensors: Optional[bool] = False, + output_dir: str, + loftq_bits: int = 4, + loftq_iter: int = 4, + lora_alpha: int = None, + lora_rank: int = 16, + lora_dropout: float = 0, + lora_target: str = "q_proj,v_proj", + save_safetensors: bool = True, ): + r""" + Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) + Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir + """ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter) @@ -57,25 +51,34 @@ def quantize_loftq( inference_mode=True, r=lora_rank, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, - lora_dropout=0.1, + lora_dropout=lora_dropout, target_modules=[name.strip() for name in lora_target.split(",")], init_lora_weights="loftq", loftq_config=loftq_config, ) # Init LoftQ model - lora_model = get_peft_model(model, lora_config) - base_model: "PreTrainedModel" = lora_model.get_base_model() + print("Initializing LoftQ weights, it may be take several minutes, wait patiently.") + peft_model = get_peft_model(model, lora_config) + loftq_dir = os.path.join(output_dir, "loftq_init") # Save LoftQ model - setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir) - setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True) - lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors) + setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir) + setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again + peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors) + print("Adapter weights saved in {}".format(loftq_dir)) # Save base model - unwrap_model(base_model) - base_model.save_pretrained(save_dir, safe_serialization=save_safetensors) - tokenizer.save_pretrained(save_dir) + base_model: 
"PreTrainedModel" = peft_model.unload() + base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) + tokenizer.save_pretrained(output_dir) + print("Model weights saved in {}".format(output_dir)) + + print("- Fine-tune this model with:") + print("model_name_or_path: {}".format(output_dir)) + print("adapter_name_or_path: {}".format(loftq_dir)) + print("finetuning_type: lora") + print("quantization_bit: {}".format(loftq_bits)) if __name__ == "__main__": diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py new file mode 100644 index 00000000..10b81efc --- /dev/null +++ b/scripts/pissa_init.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is based on the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING + +import fire +from peft import LoraConfig, TaskType, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +def quantize_pissa( + model_name_or_path: str, + output_dir: str, + pissa_iter: int = 4, + lora_alpha: int = None, + lora_rank: int = 16, + lora_dropout: float = 0, + lora_target: str = "q_proj,v_proj", + save_safetensors: bool = True, +): + r""" + Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA) + Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir + """ + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=lora_rank, + lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, + lora_dropout=lora_dropout, + target_modules=[name.strip() for name in lora_target.split(",")], + init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter) + ) + + # Init PiSSA model + peft_model = get_peft_model(model, lora_config) + pissa_dir = os.path.join(output_dir, "pissa_init") + + # Save PiSSA model + setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again + peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors) + print("Adapter weights saved in {}".format(pissa_dir)) + + # Save base model + base_model: "PreTrainedModel" = peft_model.unload() + base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) + tokenizer.save_pretrained(output_dir) + print("Model weights saved in {}".format(output_dir)) + + print("- Fine-tune this model with:") + print("model_name_or_path: {}".format(output_dir)) + print("adapter_name_or_path: {}".format(pissa_dir)) + print("finetuning_type: lora") + print("pissa_init: false") + print("pissa_convert: true") + print("- and optionally with:") + 
print("quantization_bit: 4") + + +if __name__ == "__main__": + fire.Fire(quantize_pissa) diff --git a/tests/test_toolcall.py b/scripts/test_toolcall.py similarity index 78% rename from tests/test_toolcall.py rename to scripts/test_toolcall.py index d36e7fec..6f6fd06c 100644 --- a/tests/test_toolcall.py +++ b/scripts/test_toolcall.py @@ -1,3 +1,18 @@ +# coding=utf-8 +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from typing import Sequence @@ -20,7 +35,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float: def main(): client = OpenAI( - api_key="0", + api_key="{}".format(os.environ.get("API_KEY", "0")), base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)), ) tools = [ diff --git a/setup.py b/setup.py index 4d948450..3d2ac921 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import re @@ -5,7 +19,7 @@ from setuptools import find_packages, setup def get_version(): - with open(os.path.join("src", "llamafactory", "cli.py"), "r", encoding="utf-8") as f: + with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f: file_content = f.read() pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") (version,) = re.findall(pattern, file_content) @@ -21,18 +35,19 @@ def get_requires(): extra_require = { "torch": ["torch>=1.13.1"], + "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"], "metrics": ["nltk", "jieba", "rouge-chinese"], - "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"], + "deepspeed": ["deepspeed>=0.10.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"], - "vllm": ["vllm>=0.4.0"], + "vllm": ["vllm>=0.4.3"], "galore": ["galore-torch"], "badam": ["badam"], "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"], "awq": ["autoawq"], "aqlm": ["aqlm[gpu]>=1.1.0"], - "qwen": ["tiktoken", "transformers_stream_generator"], + "qwen": ["transformers_stream_generator"], "modelscope": ["modelscope"], - "quality": ["ruff"], + "dev": ["ruff", "pytest"], } diff --git a/src/api.py b/src/api.py index 3655e393..0f925497 100644 --- a/src/api.py +++ b/src/api.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import uvicorn diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index b889e268..9d732777 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -1,4 +1,18 @@ -# Level: api, webui > chat, eval, train > data, model > extras, hparams +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Level: api, webui > chat, eval, train > data, model > hparams > extras from .cli import VERSION diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py index 21edab2f..c1264617 100644 --- a/src/llamafactory/api/app.py +++ b/src/llamafactory/api/app.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from contextlib import asynccontextmanager from typing import Optional diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index b7a08f0b..a2074dbb 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -1,10 +1,27 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import base64 +import io import json +import os import uuid from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple from ..data import Role as DataRole from ..extras.logging import get_logger -from ..extras.packages import is_fastapi_available +from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available from .common import dictify, jsonify from .protocol import ( ChatCompletionMessage, @@ -25,7 +42,17 @@ if is_fastapi_available(): from fastapi import HTTPException, status +if is_pillow_available(): + from PIL import Image + + +if is_requests_available(): + import requests + + if TYPE_CHECKING: + from numpy.typing import NDArray + from ..chat import ChatModel from .protocol import ChatCompletionRequest, ScoreEvaluationRequest @@ -40,7 +67,9 @@ ROLE_MAPPING = { } -def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]: +def _process_request( + request: "ChatCompletionRequest", +) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]: logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False))) if len(request.messages) == 0: @@ -49,12 +78,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s if request.messages[0].role == Role.SYSTEM: system = request.messages.pop(0).content else: - system = "" + system = None if len(request.messages) % 2 == 0: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") input_messages = [] + image = None for i, message in enumerate(request.messages): if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") @@ -66,6 +96,21 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s arguments = message.tool_calls[0].function.arguments content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False) input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) + elif isinstance(message.content, list): + for input_item in message.content: + if input_item.type == "text": + input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text}) + else: + image_url = input_item.image_url.url + if image_url.startswith("data:image"): # base64 image + image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1]) + image_path = io.BytesIO(image_data) + elif os.path.isfile(image_url): # local file + image_path = open(image_url, "rb") + else: # web uri + image_path = requests.get(image_url, stream=True).raw + + image = Image.open(image_path).convert("RGB") else: input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) @@ -76,9 +121,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s except Exception: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") else: - tools = "" + tools = None - return input_messages, system, tools + return input_messages, system, tools, image def _create_stream_chat_completion_chunk( @@ -97,11 +142,12 @@ async def create_chat_completion_response( request: "ChatCompletionRequest", chat_model: "ChatModel" ) -> "ChatCompletionResponse": completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) - input_messages, system, tools = _process_request(request) + input_messages, system, tools, image = _process_request(request) responses = await chat_model.achat( 
input_messages, system, tools, + image, do_sample=request.do_sample, temperature=request.temperature, top_p=request.top_p, @@ -145,7 +191,7 @@ async def create_stream_chat_completion_response( request: "ChatCompletionRequest", chat_model: "ChatModel" ) -> AsyncGenerator[str, None]: completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) - input_messages, system, tools = _process_request(request) + input_messages, system, tools, image = _process_request(request) if tools: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") @@ -159,6 +205,7 @@ async def create_stream_chat_completion_response( input_messages, system, tools, + image, do_sample=request.do_sample, temperature=request.temperature, top_p=request.top_p, diff --git a/src/llamafactory/api/common.py b/src/llamafactory/api/common.py index 5ad9a071..d1ac94de 100644 --- a/src/llamafactory/api/common.py +++ b/src/llamafactory/api/common.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from typing import TYPE_CHECKING, Any, Dict diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py index 525fa6a7..a69132ea 100644 --- a/src/llamafactory/api/protocol.py +++ b/src/llamafactory/api/protocol.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from enum import Enum, unique from typing import Any, Dict, List, Optional, Union @@ -56,9 +70,19 @@ class FunctionCall(BaseModel): function: Function +class ImageURL(BaseModel): + url: str + + +class MultimodalInputItem(BaseModel): + type: Literal["text", "image_url"] + text: Optional[str] = None + image_url: Optional[ImageURL] = None + + class ChatMessage(BaseModel): role: Role - content: Optional[str] = None + content: Optional[Union[str, List[MultimodalInputItem]]] = None tool_calls: Optional[List[FunctionCall]] = None diff --git a/src/llamafactory/chat/__init__.py b/src/llamafactory/chat/__init__.py index a1a79de6..07276d48 100644 --- a/src/llamafactory/chat/__init__.py +++ b/src/llamafactory/chat/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .base_engine import BaseEngine from .chat_model import ChatModel diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py index 65b6c59c..92a51ebe 100644 --- a/src/llamafactory/chat/base_engine.py +++ b/src/llamafactory/chat/base_engine.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index 281ef0c1..fb800106 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -1,3 +1,20 @@ +# Copyright 2024 THUDM and the LlamaFactory team. +# +# This code is inspired by the THUDM's ChatGLM implementation. +# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio from threading import Thread from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 2148f8cd..a7ff7015 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import asyncio import concurrent.futures import os @@ -8,6 +22,7 @@ import torch from transformers import GenerationConfig, TextIteratorStreamer from ..data import get_template_and_fix_tokenizer +from ..extras.logging import get_logger from ..extras.misc import get_logits_processor from ..model import load_model, load_tokenizer from .base_engine import BaseEngine, Response @@ -23,6 +38,9 @@ if TYPE_CHECKING: from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments +logger = get_logger(__name__) + + class HuggingfaceEngine(BaseEngine): def __init__( self, @@ -79,6 +97,7 @@ class HuggingfaceEngine(BaseEngine): prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) + attention_mask = torch.ones_like(inputs, dtype=torch.bool) do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) temperature: Optional[float] = input_kwargs.pop("temperature", None) @@ -92,7 +111,7 @@ class HuggingfaceEngine(BaseEngine): stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) if stop is not None: - raise ValueError("Stop parameter is not supported in Huggingface engine yet.") + logger.warning("Stop parameter is not supported in Huggingface engine yet.") generating_args = generating_args.copy() generating_args.update( @@ -132,6 +151,7 @@ class HuggingfaceEngine(BaseEngine): gen_kwargs = dict( inputs=inputs, + attention_mask=attention_mask, generation_config=GenerationConfig(**generating_args), logits_processor=get_logits_processor(), ) diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 3310a864..d488a039 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -1,19 +1,37 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import uuid from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from ..data import get_template_and_fix_tokenizer from ..extras.logging import get_logger -from ..extras.misc import get_device_count, infer_optim_dtype -from ..extras.packages import is_vllm_available +from ..extras.misc import get_device_count +from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5 from ..model import load_config, load_tokenizer -from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM +from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM from .base_engine import BaseEngine, Response if is_vllm_available(): from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams from vllm.lora.request import LoRARequest - from vllm.sequence import MultiModalData + + if is_vllm_version_greater_than_0_5(): + from vllm.multimodal.image import ImagePixelData + else: + from vllm.sequence import MultiModalData if TYPE_CHECKING: @@ -35,8 +53,6 @@ class VllmEngine(BaseEngine): generating_args: "GeneratingArguments", ) -> None: config = load_config(model_args) # may download model from ms hub - infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) - infer_dtype = str(infer_dtype).split(".")[-1] self.can_generate = finetuning_args.stage == "sft" tokenizer_module = load_tokenizer(model_args) @@ -50,7 +66,7 @@ class VllmEngine(BaseEngine): "model": model_args.model_name_or_path, "trust_remote_code": True, "download_dir": model_args.cache_dir, - "dtype": infer_dtype, + "dtype": model_args.infer_dtype, "max_model_len": model_args.vllm_maxlen, "tensor_parallel_size": get_device_count() or 1, "gpu_memory_utilization": model_args.vllm_gpu_util, @@ -70,7 +86,6 @@ class VllmEngine(BaseEngine): engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size) engine_args["image_feature_size"] = self.image_feature_size if getattr(config, "is_yi_vl_derived_model", None): - # bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828 import vllm.model_executor.models.llava logger.info("Detected Yi-VL model, applying projector patch.") @@ -109,7 +124,10 @@ class VllmEngine(BaseEngine): if self.processor is not None and image is not None: # add image features image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") pixel_values = image_processor(image, return_tensors="pt")["pixel_values"] - multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) + if is_vllm_version_greater_than_0_5(): + multi_modal_data = ImagePixelData(image=pixel_values) + else: # TODO: remove vllm 0.4.3 support + multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) else: multi_modal_data = None @@ -158,12 +176,10 @@ class VllmEngine(BaseEngine): ) result_generator = self.model.generate( - prompt=None, + inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data}, sampling_params=sampling_params, request_id=request_id, - prompt_token_ids=prompt_ids, lora_request=self.lora_request, - multi_modal_data=multi_modal_data, ) return result_generator diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index f9b63ded..c7f136b3 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -1,9 +1,30 @@ +# Copyright 2024 the LlamaFactory team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import subprocess import sys from enum import Enum, unique +from . import launcher from .api.app import run_api from .chat.chat_model import run_chat from .eval.evaluator import run_eval +from .extras.env import VERSION, print_env +from .extras.logging import get_logger +from .extras.misc import get_device_count from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui @@ -23,8 +44,6 @@ USAGE = ( + "-" * 70 ) -VERSION = "0.7.2.dev0" - WELCOME = ( "-" * 58 + "\n" @@ -37,11 +56,14 @@ WELCOME = ( + "-" * 58 ) +logger = get_logger(__name__) + @unique class Command(str, Enum): API = "api" CHAT = "chat" + ENV = "env" EVAL = "eval" EXPORT = "export" TRAIN = "train" @@ -57,12 +79,35 @@ def main(): run_api() elif command == Command.CHAT: run_chat() + elif command == Command.ENV: + print_env() elif command == Command.EVAL: run_eval() elif command == Command.EXPORT: export_model() elif command == Command.TRAIN: - run_exp() + force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"] + if force_torchrun or get_device_count() > 1: + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) + logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) + subprocess.run( + ( + "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " + "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" + ).format( + nnodes=os.environ.get("NNODES", "1"), + node_rank=os.environ.get("RANK", "0"), + nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())), + master_addr=master_addr, + master_port=master_port, + file_name=launcher.__file__, + args=" ".join(sys.argv[1:]), + ), + shell=True, + ) + else: + run_exp() elif command == Command.WEBDEMO: run_web_demo() elif command == Command.WEBUI: diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index 44887d24..307853bc 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -1,16 +1,30 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding +from .data_utils import Role, split_dataset from .loader import get_dataset -from .template import Template, get_template_and_fix_tokenizer, templates -from .utils import Role, split_dataset +from .template import TEMPLATES, Template, get_template_and_fix_tokenizer __all__ = [ "KTODataCollatorWithPadding", "PairwiseDataCollatorWithPadding", - "get_dataset", - "Template", - "get_template_and_fix_tokenizer", - "templates", "Role", "split_dataset", + "get_dataset", + "TEMPLATES", + "Template", + "get_template_and_fix_tokenizer", ] diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index 2a382c60..299bdca3 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from functools import partial from typing import TYPE_CHECKING, Any, Dict, List, Union @@ -5,11 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union from datasets import Features from ..extras.logging import get_logger -from .utils import Role +from .data_utils import Role if TYPE_CHECKING: from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments from ..hparams import DataArguments from .parser import DatasetAttr @@ -175,7 +190,10 @@ def convert_sharegpt( def align_dataset( - dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments" + dataset: Union["Dataset", "IterableDataset"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", ) -> Union["Dataset", "IterableDataset"]: r""" Aligned dataset: @@ -208,7 +226,7 @@ def align_dataset( if not data_args.streaming: kwargs = dict( num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not data_args.overwrite_cache), + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), desc="Converting format of dataset", ) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 1dc8dd8d..e4859ff5 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from dataclasses import dataclass from typing import Any, Dict, Sequence diff --git a/src/llamafactory/data/utils.py b/src/llamafactory/data/data_utils.py similarity index 84% rename from src/llamafactory/data/utils.py rename to src/llamafactory/data/data_utils.py index 9b313112..cc9761b1 100644 --- a/src/llamafactory/data/utils.py +++ b/src/llamafactory/data/data_utils.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from enum import Enum, unique from typing import TYPE_CHECKING, Dict, List, Tuple, Union diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 0cd3d6c1..590e682b 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import re from abc import ABC, abstractmethod diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 48d28f1d..f44ef5de 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -1,24 +1,38 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import inspect import os import sys from typing import TYPE_CHECKING, Literal, Optional, Union +import numpy as np from datasets import load_dataset, load_from_disk from ..extras.constants import FILEEXT2TYPE from ..extras.logging import get_logger from ..extras.misc import has_tokenized_data from .aligner import align_dataset +from .data_utils import merge_dataset from .parser import get_dataset_list from .preprocess import get_preprocess_and_print_func from .template import get_template_and_fix_tokenizer -from .utils import merge_dataset if TYPE_CHECKING: from datasets import Dataset, IterableDataset - from transformers import ProcessorMixin, Seq2SeqTrainingArguments - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments from ..hparams import DataArguments, ModelArguments from .parser import DatasetAttr @@ -31,6 +45,7 @@ def load_single_dataset( dataset_attr: "DatasetAttr", model_args: "ModelArguments", data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", ) -> Union["Dataset", "IterableDataset"]: logger.info("Loading dataset {}...".format(dataset_attr)) data_path, data_name, data_dir, data_files = None, None, None, None @@ -61,9 +76,9 @@ def load_single_dataset( raise ValueError("File {} not found.".format(local_path)) if data_path is None: - raise ValueError("File extension must be txt, csv, json or jsonl.") + raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys()))) else: - raise NotImplementedError + raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from)) if dataset_attr.load_from == "ms_hub": try: @@ -106,18 +121,30 @@ def load_single_dataset( if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter - if data_args.max_samples is not None: # truncate dataset - num_samples = min(data_args.max_samples, len(dataset)) - dataset = dataset.select(range(num_samples)) + if dataset_attr.num_samples is not None and not data_args.streaming: + target_num = dataset_attr.num_samples + indexes = np.random.permutation(len(dataset))[:target_num] + target_num -= len(indexes) + if target_num > 0: + expand_indexes = np.random.choice(len(dataset), target_num) + indexes = np.concatenate((indexes, expand_indexes), axis=0) - return align_dataset(dataset, dataset_attr, data_args) + assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." 
+ dataset = dataset.select(indexes) + logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) + + if data_args.max_samples is not None: # truncate dataset + max_samples = min(data_args.max_samples, len(dataset)) + dataset = dataset.select(range(max_samples)) + + return align_dataset(dataset, dataset_attr, data_args, training_args) def get_dataset( model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "kto"], + stage: Literal["pt", "sft", "rm", "ppo", "kto"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, ) -> Union["Dataset", "IterableDataset"]: @@ -144,7 +171,8 @@ def get_dataset( if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): raise ValueError("The dataset is not applicable in the current training stage.") - all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) + all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args)) + dataset = merge_dataset(all_datasets, data_args, training_args) with training_args.main_process_first(desc="pre-process dataset"): @@ -156,7 +184,7 @@ def get_dataset( if not data_args.streaming: kwargs = dict( num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not data_args.overwrite_cache), + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), desc="Running tokenizer on dataset", ) @@ -166,7 +194,7 @@ def get_dataset( if training_args.should_save: dataset.save_to_disk(data_args.tokenized_path) logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) - logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path)) + logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) sys.exit(0) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 679f8ad6..4bebcd68 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import json import os from dataclasses import dataclass @@ -20,11 +34,12 @@ class DatasetAttr: """ basic configs """ load_from: Literal["hf_hub", "ms_hub", "script", "file"] dataset_name: str + formatting: Literal["alpaca", "sharegpt"] = "alpaca" + ranking: bool = False """ extra configs """ subset: Optional[str] = None folder: Optional[str] = None - ranking: bool = False - formatting: Literal["alpaca", "sharegpt"] = "alpaca" + num_samples: Optional[int] = None """ common columns """ system: Optional[str] = None tools: Optional[str] = None @@ -102,10 +117,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: else: dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) + dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") + dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name]) - dataset_attr.set_attr("ranking", dataset_info[name], default=False) - dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") + dataset_attr.set_attr("num_samples", dataset_info[name]) if "columns" in dataset_info[name]: column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 336257ca..9a8b97f3 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from functools import partial from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple @@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu if TYPE_CHECKING: - from transformers import ProcessorMixin, Seq2SeqTrainingArguments - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments from ..hparams import DataArguments from .template import Template @@ -23,7 +36,7 @@ if TYPE_CHECKING: def get_preprocess_and_print_func( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "kto"], + stage: Literal["pt", "sft", "rm", "ppo", "kto"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index 1aaff0ab..219ab353 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -1,13 +1,26 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .mm_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -16,6 +29,55 @@ if TYPE_CHECKING: logger = get_logger(__name__) +def _encode_feedback_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + kl_response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int], List[int], List[int], bool]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + if response[0]["content"]: # desired example + kto_tag = True + messages = prompt + [response[0]] + else: # undesired example + kto_tag = False + messages = prompt + [response[1]] + + if kl_response[0]["content"]: + kl_messages = prompt + [kl_response[0]] + else: + kl_messages = prompt + [kl_response[1]] + + prompt_ids, response_ids = template.encode_oneturn( + tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len + ) + _, kl_response_ids = template.encode_oneturn( + tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len + ) + + if template.efficient_eos: + response_ids += [tokenizer.eos_token_id] + kl_response_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + input_ids = prompt_ids + response_ids + labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + kl_input_ids = prompt_ids + kl_response_ids + kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids + + return input_ids, labels, kl_input_ids, kl_labels, kto_tag + + def preprocess_feedback_dataset( examples: Dict[str, List[Any]], template: "Template", @@ -45,50 +107,17 @@ def preprocess_feedback_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] - - if examples["response"][i][0]["content"]: # desired example - kto_tag = True - messages = examples["prompt"][i] + [examples["response"][i][0]] - else: # undesired example - kto_tag = False - messages = examples["prompt"][i] + 
[examples["response"][i][1]] - - if kl_response[i][0]["content"]: - kl_messages = examples["prompt"][i] + [kl_response[i][0]] - else: - kl_messages = examples["prompt"][i] + [kl_response[i][1]] - - prompt_ids, response_ids = template.encode_oneturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, + input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + kl_response=kl_response[i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, ) - _, kl_response_ids = template.encode_oneturn( - tokenizer, - kl_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - response_ids += [tokenizer.eos_token_id] - kl_response_ids += [tokenizer.eos_token_id] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids - - input_ids = prompt_ids + response_ids - labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids - kl_input_ids = prompt_ids + kl_response_ids - kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) diff --git a/src/llamafactory/data/processors/mm_utils.py b/src/llamafactory/data/processors/mm_utils.py deleted file mode 100644 index abc7c4b2..00000000 --- a/src/llamafactory/data/processors/mm_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import TYPE_CHECKING, List, Sequence - -from ...extras.packages import is_pillow_available - - -if is_pillow_available(): - from PIL import Image - - -if TYPE_CHECKING: - from numpy.typing import NDArray - from PIL.Image import Image as ImageObject - from transformers import ProcessorMixin - from transformers.image_processing_utils import BaseImageProcessor - - -def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": - # process visual inputs (currently only supports a single image) - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) - return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W) - - -def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]: - # get paligemma token type ids for computing loss - image_seq_length = getattr(processor, "image_seq_length") - return [0] * image_seq_length + [1] * (input_len - image_seq_length) diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index 69dab34a..b2939348 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -1,13 +1,26 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .mm_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -16,6 +29,44 @@ if TYPE_CHECKING: logger = get_logger(__name__) +def _encode_pairwise_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int], List[int], List[int]]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + chosen_messages = prompt + [response[0]] + rejected_messages = prompt + [response[1]] + prompt_ids, chosen_ids = template.encode_oneturn( + tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len + ) + _, rejected_ids = template.encode_oneturn( + tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len + ) + + if template.efficient_eos: + chosen_ids += [tokenizer.eos_token_id] + rejected_ids += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + + chosen_input_ids = prompt_ids + chosen_ids + chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids + rejected_input_ids = prompt_ids + rejected_ids + rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids + + return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels + + def preprocess_pairwise_dataset( examples: Dict[str, List[Any]], template: "Template", @@ -43,40 +94,16 @@ def preprocess_pairwise_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] - - chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] - rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] - prompt_ids, chosen_ids = template.encode_oneturn( - tokenizer, - chosen_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, + chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( + prompt=examples["prompt"][i], + 
response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, ) - _, rejected_ids = template.encode_oneturn( - tokenizer, - rejected_messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - - if template.efficient_eos: - chosen_ids += [tokenizer.eos_token_id] - rejected_ids += [tokenizer.eos_token_id] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids - - chosen_input_ids = prompt_ids + chosen_ids - chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids - rejected_input_ids = prompt_ids + rejected_ids - rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids model_inputs["chosen_input_ids"].append(chosen_input_ids) model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) model_inputs["chosen_labels"].append(chosen_labels) diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py index 3de0d1ac..67d6009b 100644 --- a/src/llamafactory/data/processors/pretrain.py +++ b/src/llamafactory/data/processors/pretrain.py @@ -1,9 +1,26 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Dict
 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer
     from ...hparams import DataArguments
@@ -12,13 +29,14 @@ def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
-    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
     if not data_args.packing:
         if data_args.template == "gemma":
             text_examples = [tokenizer.bos_token + example for example in text_examples]
-        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
     else:
         tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py
new file mode 100644
index 00000000..93df0cd5
--- /dev/null
+++ b/src/llamafactory/data/processors/processor_utils.py
@@ -0,0 +1,78 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bisect
+from typing import TYPE_CHECKING, List, Sequence
+
+from ...extras.packages import is_pillow_available
+
+
+if is_pillow_available():
+    from PIL import Image
+
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+    from PIL.Image import Image as ImageObject
+    from transformers import ProcessorMixin
+    from transformers.image_processing_utils import BaseImageProcessor
+
+
+def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+    r"""
+    Finds the index of largest number that fits into the knapsack with the given capacity.
+    """
+    index = bisect.bisect(numbers, capacity)
+    return -1 if index == 0 else (index - 1)
+
+
+def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
+    r"""
+    An efficient greedy algorithm with binary search for the knapsack problem.
+    """
+    numbers.sort()  # sort numbers in ascending order for binary search
+    knapsacks = []
+
+    while numbers:
+        current_knapsack = []
+        remaining_capacity = capacity
+
+        while True:
+            index = search_for_fit(numbers, remaining_capacity)
+            if index == -1:
+                break  # no more numbers fit in this knapsack
+
+            remaining_capacity -= numbers[index]  # update the remaining capacity
+            current_knapsack.append(numbers.pop(index))  # add the number to knapsack
+
+        knapsacks.append(current_knapsack)
+
+    return knapsacks
+
+
+def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+    r"""
+    Processes visual inputs. (currently only supports a single image)
+    """
+    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+    return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)
+
+
+def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
+    r"""
+    Gets paligemma token type ids for computing loss.
+    """
+    image_seq_length = getattr(processor, "image_seq_length")
+    return [0] * image_seq_length + [1] * (input_len - image_seq_length)
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index b119aa22..eb5ffb1a 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -1,13 +1,27 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .mm_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -16,6 +30,48 @@ if TYPE_CHECKING: logger = get_logger(__name__) +def _encode_supervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int]]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + messages = prompt + response + input_ids, labels = [], [] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + input_ids += [image_token_id] * getattr(processor, "image_seq_length") + labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") + + encoded_pairs = template.encode_multiturn( + tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len + ) + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + return input_ids, labels + + def preprocess_supervised_dataset( examples: Dict[str, List[Any]], template: "Template", @@ -36,41 +92,16 @@ def preprocess_supervised_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] - - messages = examples["prompt"][i] + examples["response"][i] - input_ids, labels = [], [] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - input_ids += [image_token_id] * getattr(processor, "image_seq_length") - labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") - - for turn_idx, (source_ids, target_ids) in enumerate( - template.encode_multiturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif turn_idx != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) - - input_ids 
+= source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] - + input_ids, labels = _encode_supervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) @@ -90,41 +121,55 @@ def preprocess_packed_supervised_dataset( ) -> Dict[str, List[List[int]]]: # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - input_ids, labels = [], [] + valid_num = 0 + batch_input_ids, batch_labels = [], [] + lengths = [] + length2indexes = defaultdict(list) for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - messages = examples["prompt"][i] + examples["response"][i] - for source_ids, target_ids in template.encode_multiturn( - tokenizer, messages, examples["system"][i], examples["tools"][i] - ): - if data_args.train_on_prompt: - source_mask = source_ids - elif len(input_ids) != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) - else: - source_mask = [IGNORE_INDEX] * len(source_ids) + input_ids, labels = _encode_supervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=None, + data_args=data_args, + ) + length = len(input_ids) + if length > data_args.cutoff_len: + logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) + else: + lengths.append(length) + length2indexes[length].append(valid_num) + batch_input_ids.append(input_ids) + batch_labels.append(labels) + valid_num += 1 - input_ids += source_ids + target_ids - labels += source_mask + target_ids + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + knapsacks = greedy_knapsack(lengths, data_args.cutoff_len) + for knapsack in knapsacks: + packed_input_ids, packed_labels = [], [] + for length in knapsack: + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_labels += batch_labels[index] - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] + if len(packed_input_ids) < data_args.cutoff_len: + pad_length = data_args.cutoff_len - len(packed_input_ids) + packed_input_ids += [tokenizer.pad_token_id] * pad_length + packed_labels += [IGNORE_INDEX] * pad_length - total_length = len(input_ids) - block_size = data_args.cutoff_len - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of cutoff_len - for i in range(0, total_length, block_size): - if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]): - model_inputs["input_ids"].append(input_ids[i : i + block_size]) - model_inputs["attention_mask"].append([1] * block_size) - model_inputs["labels"].append(labels[i : i + 
block_size]) + if len(packed_input_ids) != data_args.cutoff_len: + raise ValueError("The length of packed example should be identical to the cutoff length.") + + model_inputs["input_ids"].append(packed_input_ids) + model_inputs["attention_mask"].append([1] * data_args.cutoff_len) + model_inputs["labels"].append(packed_labels) return model_inputs diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index 6a9f9460..75ad4d51 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -1,13 +1,26 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.logging import get_logger -from ..utils import Role -from .mm_utils import get_paligemma_token_type_ids, get_pixel_values +from ..data_utils import Role +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values if TYPE_CHECKING: - from transformers import ProcessorMixin - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments from ..template import Template @@ -16,6 +29,37 @@ if TYPE_CHECKING: logger = get_logger(__name__) +def _encode_unsupervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int]]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + if len(response) == 1: + messages = prompt + response + else: + messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] + + input_ids, labels = template.encode_oneturn( + tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len + ) + if template.efficient_eos: + labels += [tokenizer.eos_token_id] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids + + return input_ids, labels + + def preprocess_unsupervised_dataset( examples: Dict[str, List[Any]], template: "Template", @@ -35,30 +79,16 @@ def preprocess_unsupervised_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] - - if len(examples["response"][i]) == 1: - messages = 
examples["prompt"][i] + examples["response"][i] - else: - messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}] - - input_ids, labels = template.encode_oneturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, + input_ids, labels = _encode_unsupervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + tools=examples["tools"][i], + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, ) - - if template.efficient_eos: - labels += [tokenizer.eos_token_id] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids - model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 979390ce..786c679f 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1,9 +1,23 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from ..extras.logging import get_logger +from .data_utils import Role, infer_max_len from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter -from .utils import Role, infer_max_len if TYPE_CHECKING: @@ -196,7 +210,7 @@ class Llama2Template(Template): return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) -templates: Dict[str, Template] = {} +TEMPLATES: Dict[str, Template] = {} def _register_template( @@ -248,7 +262,7 @@ def _register_template( default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) default_tool_formatter = ToolFormatter(tool_format="default") default_separator_formatter = EmptyFormatter() - templates[name] = template_class( + TEMPLATES[name] = template_class( format_user=format_user or default_user_formatter, format_assistant=format_assistant or default_assistant_formatter, format_system=format_system or default_user_formatter, @@ -348,9 +362,9 @@ def get_template_and_fix_tokenizer( name: Optional[str] = None, ) -> Template: if name is None: - template = templates["empty"] # placeholder + template = TEMPLATES["empty"] # placeholder else: - template = templates.get(name, None) + template = TEMPLATES.get(name, None) if template is None: raise ValueError("Template {} does not exist.".format(name)) @@ -544,8 +558,13 @@ _register_template( ) ] ), - format_system=EmptyFormatter(slots=[{"bos_token"}]), - force_system=True, + format_system=StringFormatter( + slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"] + ), + default_system=( + "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users " + "by providing thorough responses. You are trained by Cohere." + ), ) @@ -653,6 +672,19 @@ _register_template( ) +_register_template( + name="glm4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["[gMASK]{{content}}"]), + format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + force_system=True, +) + + _register_template( name="intern", format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]), @@ -682,17 +714,8 @@ _register_template( _register_template( name="llama2", format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]), format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), - default_system=( - "You are a helpful, respectful and honest assistant. " - "Always answer as helpfully as possible, while being safe. " - "Your answers should not include any harmful, unethical, " - "racist, sexist, toxic, dangerous, or illegal content. " - "Please ensure that your responses are socially unbiased and positive in nature.\n\n" - "If a question does not make any sense, or is not factually coherent, " - "explain why instead of answering something not correct. " - "If you don't know the answer to a question, please don't share false information." 
- ), ) @@ -742,7 +765,6 @@ _register_template( _register_template( name="olmo", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), - format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]), force_system=True, ) @@ -751,12 +773,28 @@ _register_template( _register_template( name="openchat", format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), - format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), force_system=True, ) +_register_template( + name="openchat-3.6", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n" + ) + ] + ), + format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + stop_words=["<|eot_id|>"], + replace_eos=True, + force_system=True, +) + + _register_template( name="orion", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), @@ -807,6 +845,15 @@ _register_template( ) +_register_template( + name="telechat", + format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), + format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]), + stop_words=["<_end>"], + replace_eos=True, +) + + _register_template( name="vicuna", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), @@ -857,6 +904,7 @@ _register_template( _register_template( name="yi", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|im_end|>"], replace_eos=True, diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py index 192f4815..d3140793 100644 --- a/src/llamafactory/eval/evaluator.py +++ b/src/llamafactory/eval/evaluator.py @@ -1,4 +1,41 @@ -# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py +# Copyright 2024 the LlamaFactory team. +# +# This code is inspired by the Dan's test library. +# https://github.com/hendrycks/test/blob/master/evaluate_flan.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
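A hedged sketch of looking up one of the newly registered chat templates through the renamed `TEMPLATES` registry; the GLM-4 checkpoint below is the one registered later in this patch, and downloading its tokenizer requires `trust_remote_code`:

```python
from transformers import AutoTokenizer

from llamafactory.data.template import TEMPLATES, get_template_and_fix_tokenizer

# "THUDM/glm-4-9b-chat" is one of the checkpoints associated with the "glm4" template.
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

# Fetches the "glm4" template and registers its stop words
# ("<|user|>", "<|observation|>") on the tokenizer when they are missing.
template = get_template_and_fix_tokenizer(tokenizer, "glm4")
print("glm4" in TEMPLATES, template.efficient_eos)
```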
+# +# MIT License +# +# Copyright (c) 2020 Dan Hendrycks +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. import inspect import json @@ -26,9 +63,7 @@ class Evaluator: self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) self.model = load_model(self.tokenizer, self.model_args, finetuning_args) self.eval_template = get_eval_template(self.eval_args.lang) - self.choice_inputs = [ - self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES - ] + self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] @torch.inference_mode() def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: diff --git a/src/llamafactory/eval/template.py b/src/llamafactory/eval/template.py index a4a6ef0e..7d524e7c 100644 --- a/src/llamafactory/eval/template.py +++ b/src/llamafactory/eval/template.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Dict, List, Sequence, Tuple @@ -10,7 +24,6 @@ class EvalTemplate: system: str choice: str answer: str - prefix: str def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: r""" @@ -42,8 +55,8 @@ class EvalTemplate: eval_templates: Dict[str, "EvalTemplate"] = {} -def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None: - eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix) +def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: + eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer) def get_eval_template(name: str) -> "EvalTemplate": @@ -56,8 +69,7 @@ _register_eval_template( name="en", system="The following are multiple choice questions (with answers) about {subject}.\n\n", choice="\n{choice}. 
{content}", - answer="\nAnswer: ", - prefix=" ", + answer="\nAnswer:", ) @@ -66,5 +78,4 @@ _register_eval_template( system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", choice="\n{choice}. {content}", answer="\n答案:", - prefix=" ", ) diff --git a/src/llamafactory/extras/callbacks.py b/src/llamafactory/extras/callbacks.py index 637b786d..0dff6a69 100644 --- a/src/llamafactory/extras/callbacks.py +++ b/src/llamafactory/extras/callbacks.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import logging import os @@ -170,12 +184,14 @@ class LogCallback(TrainerCallback): percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, elapsed_time=self.elapsed_time, remaining_time=self.remaining_time, + throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)), + total_tokens=state.num_input_tokens_seen, ) logs = {k: v for k, v in logs.items() if v is not None} if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]): logger.info( - "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( - logs["loss"], logs["learning_rate"], logs["epoch"] + "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format( + logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"] ) ) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 087612fc..73a9969d 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1,14 +1,39 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from collections import OrderedDict, defaultdict from enum import Enum from typing import Dict, Optional +from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME +from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME + + +CHECKPOINT_NAMES = { + SAFE_ADAPTER_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, +} CHOICES = ["A", "B", "C", "D"] DATA_CONFIG = "dataset_info.json" -DEFAULT_MODULE = defaultdict(str) - DEFAULT_TEMPLATE = defaultdict(str) FILEEXT2TYPE = { @@ -24,11 +49,13 @@ IGNORE_INDEX = -100 LAYERNORM_NAMES = {"norm", "ln"} +LLAMABOARD_CONFIG = "llamaboard_config.yaml" + METHODS = ["full", "freeze", "lora"] -MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"] +MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"} -PEFT_METHODS = ["lora"] +PEFT_METHODS = {"lora"} RUNNING_LOG = "running_log.txt" @@ -36,10 +63,10 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] SUPPORTED_MODELS = OrderedDict() -TRAINER_CONFIG = "trainer_config.yaml" - TRAINER_LOG = "trainer_log.jsonl" +TRAINING_ARGS = "training_args.yaml" + TRAINING_STAGES = { "Supervised Fine-Tuning": "sft", "Reward Modeling": "rm", @@ -49,9 +76,9 @@ TRAINING_STAGES = { "Pre-Training": "pt", } -STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"] +STAGES_USE_PAIR_DATA = {"rm", "dpo"} -SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] +SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} V_HEAD_WEIGHTS_NAME = "value_head.bin" @@ -67,7 +94,6 @@ class DownloadSource(str, Enum): def register_model_group( models: Dict[str, Dict[DownloadSource, str]], - module: Optional[str] = None, template: Optional[str] = None, vision: bool = False, ) -> None: @@ -78,14 +104,25 @@ def register_model_group( else: assert prefix == name.split("-")[0], "prefix should be identical." 
SUPPORTED_MODELS[name] = path - if module is not None: - DEFAULT_MODULE[prefix] = module if template is not None: DEFAULT_TEMPLATE[prefix] = template if vision: VISION_MODELS.add(prefix) +register_model_group( + models={ + "Aya-23-8B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/aya-23-8B", + }, + "Aya-23-35B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/aya-23-35B", + }, + }, + template="cohere", +) + + register_model_group( models={ "Baichuan-7B-Base": { @@ -101,7 +138,6 @@ register_model_group( DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat", }, }, - module="W_pack", template="baichuan", ) @@ -125,7 +161,6 @@ register_model_group( DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", }, }, - module="W_pack", template="baichuan2", ) @@ -145,7 +180,6 @@ register_model_group( DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1", }, }, - module="query_key_value", ) @@ -164,7 +198,6 @@ register_model_group( DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt", }, }, - module="query_key_value", ) @@ -203,7 +236,6 @@ register_model_group( DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b", } }, - module="query_key_value", template="chatglm2", ) @@ -219,7 +251,6 @@ register_model_group( DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b", }, }, - module="query_key_value", template="chatglm3", ) @@ -255,6 +286,36 @@ register_model_group( ) +register_model_group( + models={ + "CodeGemma-7B": { + DownloadSource.DEFAULT: "google/codegemma-7b", + }, + "CodeGemma-7B-Chat": { + DownloadSource.DEFAULT: "google/codegemma-7b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it", + }, + "CodeGemma-1.1-2B": { + DownloadSource.DEFAULT: "google/codegemma-1.1-2b", + }, + "CodeGemma-1.1-7B-Chat": { + DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it", + }, + }, + template="gemma", +) + + +register_model_group( + models={ + "Codestral-22B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1", + }, + }, + template="mistral", +) + + register_model_group( models={ "CommandR-35B-Chat": { @@ -288,7 +349,6 @@ register_model_group( DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct", }, }, - module="Wqkv", template="dbrx", ) @@ -407,7 +467,6 @@ register_model_group( DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat", }, }, - module="query_key_value", template="falcon", ) @@ -443,21 +502,20 @@ register_model_group( register_model_group( models={ - "CodeGemma-7B": { - DownloadSource.DEFAULT: "google/codegemma-7b", + "GLM-4-9B": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b", }, - "CodeGemma-7B-Chat": { - DownloadSource.DEFAULT: "google/codegemma-7b-it", - DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it", + "GLM-4-9B-Chat": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat", }, - "CodeGemma-1.1-2B": { - DownloadSource.DEFAULT: "google/codegemma-1.1-2b", - }, - "CodeGemma-1.1-7B-Chat": { - DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it", + "GLM-4-9B-1M-Chat": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m", }, }, - template="gemma", + template="glm4", ) @@ -503,7 +561,6 @@ register_model_group( DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b", }, }, - module="wqkv", template="intern2", ) @@ -525,7 +582,6 @@ register_model_group( DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B", } }, - module="qkv_proj", ) @@ -626,6 +682,21 @@ register_model_group( ) 
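To illustrate the updated `register_model_group` helper (which no longer takes a `module` argument), here is a sketch with a made-up entry; `My-7B-Chat` and its hub path are placeholders rather than real checkpoints:

```python
from llamafactory.extras.constants import (
    DEFAULT_TEMPLATE,
    SUPPORTED_MODELS,
    DownloadSource,
    register_model_group,
)

# Hypothetical model group used only to show the call shape.
register_model_group(
    models={
        "My-7B-Chat": {
            DownloadSource.DEFAULT: "my-org/my-7b-chat",
        },
    },
    template="llama3",
)

# The registry keys on the full name; the default template keys on the name prefix.
assert SUPPORTED_MODELS["My-7B-Chat"][DownloadSource.DEFAULT] == "my-org/my-7b-chat"
assert DEFAULT_TEMPLATE["My"] == "llama3"
```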
+register_model_group( + models={ + "MiniCPM-2B-SFT-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16", + }, + "MiniCPM-2B-DPO-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16", + }, + }, + template="cpm", +) + + register_model_group( models={ "Mistral-7B-v0.1": { @@ -707,6 +778,16 @@ register_model_group( ) +register_model_group( + models={ + "OpenChat3.6-8B-Chat": { + DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522", + } + }, + template="openchat-3.6", +) + + register_model_group( models={ "Orion-14B-Base": { @@ -802,7 +883,6 @@ register_model_group( DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct", }, }, - module="qkv_proj", template="phi", ) @@ -874,7 +954,6 @@ register_model_group( DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4", }, }, - module="c_attn", template="qwen", ) @@ -1030,6 +1109,89 @@ register_model_group( ) +register_model_group( + models={ + "Qwen2-0.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B", + }, + "Qwen2-1.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B", + }, + "Qwen2-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B", + }, + "Qwen2-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B", + }, + "Qwen2-MoE-57B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B", + }, + "Qwen2-0.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct", + }, + "Qwen2-1.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct", + }, + "Qwen2-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct", + }, + "Qwen2-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct", + }, + "Qwen2-MoE-57B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct", + }, + "Qwen2-0.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", + }, + "Qwen2-0.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ", + }, + "Qwen2-1.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8", + }, + "Qwen2-1.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ", + }, + "Qwen2-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8", + }, + "Qwen2-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ", + }, + "Qwen2-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8", + }, + "Qwen2-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ", + DownloadSource.MODELSCOPE: 
"qwen/Qwen2-72B-Instruct-AWQ", + }, + "Qwen2-MoE-57B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4", + }, + }, + template="qwen", +) + + register_model_group( models={ "SOLAR-10.7B": { @@ -1072,6 +1234,25 @@ register_model_group( ) +register_model_group( + models={ + "TeleChat-7B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/telechat-7B", + DownloadSource.MODELSCOPE: "TeleAI/telechat-7B", + }, + "TeleChat-12B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B", + }, + "TeleChat-12B-v2-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2", + }, + }, + template="telechat", +) + + register_model_group( models={ "Vicuna1.5-7B-Chat": { diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py new file mode 100644 index 00000000..586c24c0 --- /dev/null +++ b/src/llamafactory/extras/env.py @@ -0,0 +1,72 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform + +import accelerate +import datasets +import peft +import torch +import transformers +import trl +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + + +VERSION = "0.8.2.dev0" + + +def print_env() -> None: + info = { + "`llamafactory` version": VERSION, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version": torch.__version__, + "Transformers version": transformers.__version__, + "Datasets version": datasets.__version__, + "Accelerate version": accelerate.__version__, + "PEFT version": peft.__version__, + "TRL version": trl.__version__, + } + + if is_torch_cuda_available(): + info["PyTorch version"] += " (GPU)" + info["GPU type"] = torch.cuda.get_device_name() + + if is_torch_npu_available(): + info["PyTorch version"] += " (NPU)" + info["NPU type"] = torch.npu.get_device_name() + info["CANN version"] = torch.version.cann + + try: + import deepspeed # type: ignore + + info["DeepSpeed version"] = deepspeed.__version__ + except Exception: + pass + + try: + import bitsandbytes + + info["Bitsandbytes version"] = bitsandbytes.__version__ + except Exception: + pass + + try: + import vllm + + info["vLLM version"] = vllm.__version__ + except Exception: + pass + + print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n") diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 430b8a48..67622212 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os import sys diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 0dc07d28..93153b3e 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import gc import os from typing import TYPE_CHECKING, Dict, Tuple @@ -8,6 +22,7 @@ from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTr from transformers.utils import ( SAFE_WEIGHTS_NAME, WEIGHTS_NAME, + is_safetensors_available, is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_mps_available, @@ -20,6 +35,11 @@ from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from .logging import get_logger +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import save_file + + _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() try: _is_bf16_available = is_torch_bf16_gpu_available() @@ -61,11 +81,11 @@ def check_dependencies() -> None: if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") else: - require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") - require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") - require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2") - require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0") - require_version("trl>=0.8.2", "To fix: pip install trl>=0.8.2") + require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2") + require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0") + require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1") + require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1") + require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6") def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: @@ -114,9 +134,6 @@ def fix_valuehead_checkpoint( return if safe_serialization: - from safetensors import safe_open - from safetensors.torch import save_file - path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} @@ -165,13 +182,15 @@ def get_current_device() -> torch.device: def get_device_count() -> int: r""" - Gets the number of available GPU devices. 
+ Gets the number of available GPU or NPU devices. """ - if not torch.cuda.is_available(): + if is_torch_npu_available(): + return torch.npu.device_count() + elif is_torch_cuda_available(): + return torch.cuda.device_count() + else: return 0 - return torch.cuda.device_count() - def get_logits_processor() -> "LogitsProcessorList": r""" @@ -194,6 +213,13 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: return torch.float32 +def is_gpu_or_npu_available() -> bool: + r""" + Checks if the GPU or NPU is available. + """ + return is_torch_npu_available() or is_torch_cuda_available() + + def has_tokenized_data(path: os.PathLike) -> bool: r""" Checks if the path has a tokenized dataset. @@ -203,12 +229,17 @@ def has_tokenized_data(path: os.PathLike) -> bool: def torch_gc() -> None: r""" - Collects GPU memory. + Collects GPU or NPU memory. """ gc.collect() - if torch.cuda.is_available(): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(): + torch.mps.empty_cache() + elif is_torch_cuda_available(): torch.cuda.empty_cache() - torch.cuda.ipc_collect() def try_download_model_from_ms(model_args: "ModelArguments") -> str: diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 4c9e6492..0a84a293 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -1,5 +1,23 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
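A small sketch of how the broadened device helpers in `misc.py` above behave; on a CPU-only machine the NPU/XPU/MPS branches simply fall through:

```python
from llamafactory.extras.misc import get_device_count, is_gpu_or_npu_available, torch_gc

print("accelerator available:", is_gpu_or_npu_available())
print("device count:", get_device_count())  # 0 when neither CUDA nor NPU is present

torch_gc()  # runs gc.collect() and empties the cache of whichever backend is available
```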
+ import importlib.metadata import importlib.util +from functools import lru_cache from typing import TYPE_CHECKING from packaging import version @@ -24,10 +42,6 @@ def is_fastapi_available(): return _is_package_available("fastapi") -def is_flash_attn2_available(): - return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0") - - def is_galore_available(): return _is_package_available("galore_torch") @@ -36,18 +50,10 @@ def is_gradio_available(): return _is_package_available("gradio") -def is_jieba_available(): - return _is_package_available("jieba") - - def is_matplotlib_available(): return _is_package_available("matplotlib") -def is_nltk_available(): - return _is_package_available("nltk") - - def is_pillow_available(): return _is_package_available("PIL") @@ -60,10 +66,6 @@ def is_rouge_available(): return _is_package_available("rouge_chinese") -def is_sdpa_available(): - return _get_package_version("torch") > version.parse("2.1.1") - - def is_starlette_available(): return _is_package_available("sse_starlette") @@ -74,3 +76,8 @@ def is_uvicorn_available(): def is_vllm_available(): return _is_package_available("vllm") + + +@lru_cache +def is_vllm_version_greater_than_0_5(): + return _get_package_version("vllm") >= version.parse("0.5.0") diff --git a/src/llamafactory/extras/ploting.py b/src/llamafactory/extras/ploting.py index dea23bbe..596d55e7 100644 --- a/src/llamafactory/extras/ploting.py +++ b/src/llamafactory/extras/ploting.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import math import os diff --git a/src/llamafactory/hparams/__init__.py b/src/llamafactory/hparams/__init__.py index d1ee98dd..cfe448c1 100644 --- a/src/llamafactory/hparams/__init__.py +++ b/src/llamafactory/hparams/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .data_args import DataArguments from .evaluation_args import EvaluationArguments from .finetuning_args import FinetuningArguments diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 1e0cd08c..39290e21 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. 
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass, field from typing import Literal, Optional diff --git a/src/llamafactory/hparams/evaluation_args.py b/src/llamafactory/hparams/evaluation_args.py index 5a05f6f6..a7f221ca 100644 --- a/src/llamafactory/hparams/evaluation_args.py +++ b/src/llamafactory/hparams/evaluation_args.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from dataclasses import dataclass, field from typing import Literal, Optional diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 05b246ae..1ef46eca 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -1,5 +1,19 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import List, Literal, Optional @dataclass @@ -24,12 +38,7 @@ class FreezeArguments: "help": ( "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. " "Use commas to separate multiple modules. " - "Use `all` to specify all the available modules. " - "LLaMA choices: [`mlp`, `self_attn`], " - "BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], " - "Qwen choices: [`mlp`, `attn`], " - "InternLM2 choices: [`feed_forward`, `attention`], " - "Others choices: the same as LLaMA." + "Use `all` to specify all the available modules." ) }, ) @@ -79,13 +88,7 @@ class LoraArguments: "help": ( "Name(s) of target modules to apply LoRA. " "Use commas to separate multiple modules. " - "Use `all` to specify all the linear modules. 
" - "LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], " - "BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], " - "Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], " - "Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], " - "InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], " - "Others choices: the same as LLaMA." + "Use `all` to specify all the linear modules." ) }, ) @@ -105,6 +108,18 @@ class LoraArguments: default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}, ) + pissa_init: bool = field( + default=False, + metadata={"help": "Whether or not to initialize a PiSSA adapter."}, + ) + pissa_iter: int = field( + default=4, + metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."}, + ) + pissa_convert: bool = field( + default=False, + metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."}, + ) create_new_adapter: bool = field( default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, @@ -311,6 +326,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, ) + freeze_vision_tower: bool = field( + default=True, + metadata={"help": "Whether ot not to freeze vision tower in MLLM training."}, + ) + train_mm_proj_only: bool = field( + default=False, + metadata={"help": "Whether or not to train the multimodal projector for MLLM only."}, + ) plot_loss: bool = field( default=False, metadata={"help": "Whether or not to save the training loss curves."}, @@ -322,19 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA return [item.strip() for item in arg.split(",")] return arg - self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules) - self.freeze_extra_modules = split_arg(self.freeze_extra_modules) - self.lora_alpha = self.lora_alpha or self.lora_rank * 2 - self.lora_target = split_arg(self.lora_target) - self.additional_target = split_arg(self.additional_target) - self.galore_target = split_arg(self.galore_target) + self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) + self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) + self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 + self.lora_target: List[str] = split_arg(self.lora_target) + self.additional_target: Optional[List[str]] = split_arg(self.additional_target) + self.galore_target: List[str] = split_arg(self.galore_target) + self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only + self.use_ref_model = (self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]) assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." 
- self.use_ref_model = self.pref_loss not in ["orpo", "simpo"] - if self.stage == "ppo" and self.reward_model is None: raise ValueError("`reward_model` is necessary for PPO training.") @@ -345,7 +368,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") if self.use_llama_pro and self.finetuning_type == "full": - raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.") + raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.") if self.finetuning_type == "lora" and (self.use_galore or self.use_badam): raise ValueError("Cannot use LoRA with GaLore or BAdam together.") @@ -354,4 +377,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA raise ValueError("Cannot use GaLore with BAdam together.") if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": - raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.") + raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + + if self.pissa_convert and self.finetuning_type != "lora": + raise ValueError("`pissa_convert` is only valid for LoRA training.") + + if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model): + raise ValueError("Cannot use PiSSA for current training stage.") + + if self.train_mm_proj_only and self.finetuning_type != "full": + raise ValueError("`train_mm_proj_only` is only valid for full training.") diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py index 0ee17d1a..7ebb4eed 100644 --- a/src/llamafactory/hparams/generating_args.py +++ b/src/llamafactory/hparams/generating_args.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import asdict, dataclass, field from typing import Any, Dict, Optional diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 650d1c22..996e9130 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -1,5 +1,28 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
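A quick sketch of the new PiSSA-related validation in `FinetuningArguments`, constructing the dataclass directly; the field values are illustrative only:

```python
from llamafactory.hparams import FinetuningArguments

# Valid combination: PiSSA initialization plus conversion back to a plain LoRA adapter.
args = FinetuningArguments(finetuning_type="lora", pissa_init=True, pissa_convert=True)
print(args.lora_alpha)  # resolved to lora_rank * 2 in __post_init__ when left unset

# Rejected combination: pissa_convert outside of LoRA fine-tuning.
try:
    FinetuningArguments(finetuning_type="full", pissa_convert=True)
except ValueError as err:
    print(err)  # `pissa_convert` is only valid for LoRA training.
```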
+ from dataclasses import asdict, dataclass, field -from typing import Any, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union + +from typing_extensions import Self + + +if TYPE_CHECKING: + import torch @dataclass @@ -15,7 +38,16 @@ class ModelArguments: ) adapter_name_or_path: Optional[str] = field( default=None, - metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}, + metadata={ + "help": ( + "Path to the adapter weight or identifier from huggingface.co/models. " + "Use commas to separate multiple adapters." + ) + }, + ) + adapter_folder: Optional[str] = field( + default=None, + metadata={"help": "The folder containing the adapter weights to load."}, ) cache_dir: Optional[str] = field( default=None, @@ -35,7 +67,7 @@ class ModelArguments: ) new_special_tokens: Optional[str] = field( default=None, - metadata={"help": "Special tokens to be added into the tokenizer."}, + metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."}, ) model_revision: str = field( default="main", @@ -101,13 +133,17 @@ class ModelArguments: default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}, ) + train_from_scratch: bool = field( + default=False, + metadata={"help": "Whether or not to randomly initialize the model weights."}, + ) infer_backend: Literal["huggingface", "vllm"] = field( default="huggingface", metadata={"help": "Backend engine used at inference."}, ) vllm_maxlen: int = field( default=2048, - metadata={"help": "Maximum input length of the vLLM engine."}, + metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."}, ) vllm_gpu_util: float = field( default=0.9, @@ -118,7 +154,7 @@ class ModelArguments: metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, ) vllm_max_lora_rank: int = field( - default=8, + default=32, metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, ) offload_folder: str = field( @@ -129,6 +165,10 @@ class ModelArguments: default=True, metadata={"help": "Whether or not to use KV cache in generation."}, ) + infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field( + default="auto", + metadata={"help": "Data type for model weights and activations at inference."}, + ) hf_hub_token: Optional[str] = field( default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}, @@ -145,9 +185,9 @@ class ModelArguments: default=1, metadata={"help": "The file shard size (in GB) of the exported model."}, ) - export_device: Literal["cpu", "cuda"] = field( + export_device: Literal["cpu", "auto"] = field( default="cpu", - metadata={"help": "The device used in model export, use cuda to avoid addmm errors."}, + metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, ) export_quantization_bit: Optional[int] = field( default=None, @@ -179,9 +219,9 @@ class ModelArguments: ) def __post_init__(self): - self.compute_dtype = None - self.device_map = None - self.model_max_length = None + self.compute_dtype: Optional["torch.dtype"] = None + self.device_map: Optional[Union[str, Dict[str, Any]]] = None + self.model_max_length: Optional[int] = None if self.split_special_tokens and self.use_fast_tokenizer: raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") @@ -203,3 +243,13 @@ class ModelArguments: def to_dict(self) -> Dict[str, Any]: return asdict(self) + + @classmethod + def copyfrom(cls, 
old_arg: Self, **kwargs) -> Self: + arg_dict = old_arg.to_dict() + arg_dict.update(**kwargs) + new_arg = cls(**arg_dict) + new_arg.compute_dtype = old_arg.compute_dtype + new_arg.device_map = old_arg.device_map + new_arg.model_max_length = old_arg.model_max_length + return new_arg diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index fe108657..f922bbfd 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os import sys @@ -6,11 +23,13 @@ from typing import Any, Dict, Optional, Tuple import torch import transformers from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint +from transformers.training_args import ParallelMode from transformers.utils import is_torch_bf16_gpu_available from transformers.utils.versions import require_version -from ..extras.constants import TRAINER_CONFIG +from ..extras.constants import CHECKPOINT_NAMES from ..extras.logging import get_logger from ..extras.misc import check_dependencies, get_current_device from .data_args import DataArguments @@ -64,10 +83,16 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": raise ValueError("Adapter is only valid for the LoRA method.") + if model_args.use_unsloth and is_deepspeed_zero3_enabled(): + raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") + if model_args.quantization_bit is not None: if finetuning_args.finetuning_type != "lora": raise ValueError("Quantization is only compatible with the LoRA method.") + if finetuning_args.pissa_init: + raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.") + if model_args.resize_vocab: raise ValueError("Cannot resize embedding layers of a quantized model.") @@ -90,7 +115,7 @@ def _check_extra_dependencies( require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") if model_args.infer_backend == "vllm": - require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0") + require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3") if finetuning_args.use_galore: require_version("galore_torch", "To fix: pip install galore_torch") @@ -158,6 +183,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ): raise ValueError("PPO only accepts wandb or tensorboard logger.") + if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED: + raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.") + 
if training_args.max_steps == -1 and data_args.streaming: raise ValueError("Please specify `max_steps` in streaming mode.") @@ -167,9 +195,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if training_args.do_train and model_args.quantization_device_map == "auto": raise ValueError("Cannot use device map for quantized models in training.") - if finetuning_args.use_dora and model_args.use_unsloth: - raise ValueError("Unsloth does not support DoRA.") - if finetuning_args.pure_bf16: if not is_torch_bf16_gpu_available(): raise ValueError("This device does not support `pure_bf16`.") @@ -180,16 +205,25 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if ( finetuning_args.use_galore and finetuning_args.galore_layerwise - and training_args.parallel_mode.value == "distributed" + and training_args.parallel_mode == ParallelMode.DISTRIBUTED ): raise ValueError("Distributed training does not support layer-wise GaLore.") + if ( + finetuning_args.use_badam + and finetuning_args.badam_mode == "layer" + and training_args.parallel_mode == ParallelMode.DISTRIBUTED + ): + raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.") if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None: raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.") @@ -229,7 +263,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: # Post-process training arguments if ( - training_args.parallel_mode.value == "distributed" + training_args.parallel_mode == ParallelMode.DISTRIBUTED and training_args.ddp_find_unused_parameters is None and finetuning_args.finetuning_type == "lora" ): @@ -252,17 +286,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: and can_resume_from_checkpoint ): last_checkpoint = get_last_checkpoint(training_args.output_dir) - files = os.listdir(training_args.output_dir) - if last_checkpoint is None and len(files) > 0 and (len(files) != 1 or files[0] != TRAINER_CONFIG): + if last_checkpoint is None and any( + os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES + ): raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") if last_checkpoint is not None: training_args.resume_from_checkpoint = last_checkpoint - logger.info( - "Resuming training from {}.
Change `output_dir` or use `overwrite_output_dir` to avoid.".format( - training_args.resume_from_checkpoint - ) - ) + logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint)) + logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.") if ( finetuning_args.stage in ["rm", "ppo"] @@ -291,7 +323,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: training_args.local_rank, training_args.device, training_args.n_gpu, - training_args.parallel_mode.value == "distributed", + training_args.parallel_mode == ParallelMode.DISTRIBUTED, str(model_args.compute_dtype), ) ) diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py new file mode 100644 index 00000000..65e0b68f --- /dev/null +++ b/src/llamafactory/launcher.py @@ -0,0 +1,23 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llamafactory.train.tuner import run_exp + + +def launch(): + run_exp() + + +if __name__ == "__main__": + launch() diff --git a/src/llamafactory/model/__init__.py b/src/llamafactory/model/__init__.py index 88f666c8..4abbaa1b 100644 --- a/src/llamafactory/model/__init__.py +++ b/src/llamafactory/model/__init__.py @@ -1,12 +1,26 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .loader import load_config, load_model, load_tokenizer -from .utils.misc import find_all_linear_modules -from .utils.valuehead import load_valuehead_params +from .model_utils.misc import find_all_linear_modules +from .model_utils.valuehead import load_valuehead_params __all__ = [ "load_config", "load_model", "load_tokenizer", - "load_valuehead_params", "find_all_linear_modules", + "load_valuehead_params", ] diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index f37f3bbb..a8f3a256 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + import re from typing import TYPE_CHECKING @@ -7,9 +21,9 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled from ..extras.logging import get_logger -from .utils.misc import find_all_linear_modules, find_expanded_modules -from .utils.quantization import QuantizationMethod -from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model +from .model_utils.misc import find_all_linear_modules, find_expanded_modules +from .model_utils.quantization import QuantizationMethod +from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model if TYPE_CHECKING: @@ -21,6 +35,238 @@ if TYPE_CHECKING: logger = get_logger(__name__) +def _setup_full_tuning( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> None: + if not is_trainable: + return + + logger.info("Fine-tuning method: Full") + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + + if model_args.visual_inputs and finetuning_args.train_mm_proj_only: + forbidden_modules.add("language_model") + + for name, param in model.named_parameters(): + if not any(forbidden_module in name for forbidden_module in forbidden_modules): + if cast_trainable_params_to_fp32: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) + + +def _setup_freeze_tuning( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> None: + if not is_trainable: + return + + logger.info("Fine-tuning method: Freeze") + if model_args.visual_inputs: + config = model.config.text_config + else: + config = model.config + + num_layers = ( + getattr(config, "num_hidden_layers", None) + or getattr(config, "num_layers", None) + or getattr(config, "n_layer", None) + ) + if not num_layers: + raise ValueError("Current model does not support freeze tuning.") + + if finetuning_args.use_llama_pro: + if num_layers % finetuning_args.freeze_trainable_layers != 0: + raise ValueError( + "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( + num_layers, finetuning_args.freeze_trainable_layers + ) + ) + + stride = num_layers // finetuning_args.freeze_trainable_layers + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers) + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers)) + + hidden_modules = set() + non_hidden_modules = set() + for name, _ in model.named_parameters(): + if ".0." in name: + hidden_modules.add(name.split(".0.")[-1].split(".")[0]) + elif ".1." 
in name: # MoD starts from layer 1 + hidden_modules.add(name.split(".1.")[-1].split(".")[0]) + + if re.search(r"\.\d+\.", name) is None: + non_hidden_modules.add(name.split(".")[-2]) + + trainable_layers = [] + for module_name in finetuning_args.freeze_trainable_modules: + if module_name != "all" and module_name not in hidden_modules: + raise ValueError( + "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules)) + ) + + for idx in trainable_layer_ids: + trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) + + if finetuning_args.freeze_extra_modules: + for module_name in finetuning_args.freeze_extra_modules: + if module_name not in non_hidden_modules: + raise ValueError( + "Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules)) + ) + + trainable_layers.append(module_name) + + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + + for name, param in model.named_parameters(): + if any(trainable_layer in name for trainable_layer in trainable_layers) and not any( + forbidden_module in name for forbidden_module in forbidden_modules + ): + if cast_trainable_params_to_fp32: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) + + logger.info("Set trainable layers: {}".format(",".join(trainable_layers))) + + +def _setup_lora_tuning( + config: "PretrainedConfig", + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> "PeftModel": + if is_trainable: + logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + + adapter_to_resume = None + + if model_args.adapter_name_or_path is not None: + is_mergeable = True + if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable + assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." + is_mergeable = False + + if is_deepspeed_zero3_enabled(): + assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." + is_mergeable = False + + if model_args.use_unsloth: + assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." 
+ is_mergeable = False + + if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): + adapter_to_merge = model_args.adapter_name_or_path[:-1] + adapter_to_resume = model_args.adapter_name_or_path[-1] + else: + adapter_to_merge = model_args.adapter_name_or_path + + init_kwargs = { + "subfolder": model_args.adapter_folder, + "offload_folder": model_args.offload_folder, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + + for adapter in adapter_to_merge: + model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) + model = model.merge_and_unload() + + if len(adapter_to_merge) > 0: + logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) + + if adapter_to_resume is not None: # resume lora training + if model_args.use_unsloth: + model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) + else: + model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) + + logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + + if is_trainable and adapter_to_resume is None: # create new lora weights while training + if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": + target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) + else: + target_modules = finetuning_args.lora_target + + if finetuning_args.use_llama_pro: + target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers) + + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) + + if ( + finetuning_args.use_dora + and getattr(model, "quantization_method", None) is not None + and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES + ): + raise ValueError("DoRA is not compatible with PTQ-quantized models.") + + if model_args.resize_vocab and finetuning_args.additional_target is None: + input_embeddings = model.get_input_embeddings() + output_embeddings = model.get_output_embeddings() + module_names = set() + for name, module in model.named_modules(): + if module in [input_embeddings, output_embeddings]: + module_names.add(name.split(".")[-1]) + + finetuning_args.additional_target = module_names + logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) + + peft_kwargs = { + "r": finetuning_args.lora_rank, + "target_modules": target_modules, + "lora_alpha": finetuning_args.lora_alpha, + "lora_dropout": finetuning_args.lora_dropout, + "use_rslora": finetuning_args.use_rslora, + "use_dora": finetuning_args.use_dora, + "modules_to_save": finetuning_args.additional_target, + } + + if model_args.use_unsloth: + model = get_unsloth_peft_model(model, model_args, peft_kwargs) + else: + if finetuning_args.pissa_init: + if finetuning_args.pissa_iter == -1: + logger.info("Using PiSSA initialization.") + peft_kwargs["init_lora_weights"] = "pissa" + else: + logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) + peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + **peft_kwargs, + ) + model = get_peft_model(model, lora_config) + + if is_trainable and cast_trainable_params_to_fp32: + for param in filter(lambda p: 
p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + + return model + + def init_adapter( config: "PretrainedConfig", model: "PreTrainedModel", @@ -35,194 +281,27 @@ def init_adapter( Note that the trainable parameters must be cast to float32. """ + if is_trainable and getattr(model, "quantization_method", None) and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantized models can only be used for the LoRA tuning.") - if (not is_trainable) and model_args.adapter_name_or_path is None: - logger.info("Adapter is not found at evaluation, load the base model.") - return model - - if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None): - raise ValueError("You can only use lora for quantized models.") - - if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam: + if not is_trainable: + cast_trainable_params_to_fp32 = False + elif is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam: logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.") cast_trainable_params_to_fp32 = False else: logger.info("Upcasting trainable params to float32.") cast_trainable_params_to_fp32 = True - if finetuning_args.finetuning_type == "full" and is_trainable: - logger.info("Fine-tuning method: Full") - if cast_trainable_params_to_fp32: - model = model.float() - - if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model - model.vision_tower.requires_grad_(False) - - if finetuning_args.finetuning_type == "freeze" and is_trainable: - logger.info("Fine-tuning method: Freeze") - num_layers = ( - getattr(model.config, "num_hidden_layers", None) - or getattr(model.config, "num_layers", None) - or getattr(model.config, "n_layer", None) + if finetuning_args.finetuning_type == "full": + _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "freeze": + _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "lora": + model = _setup_lora_tuning( + config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 ) - if not num_layers: - raise ValueError("Current model does not support freeze tuning.") - - if finetuning_args.use_llama_pro: - if num_layers % finetuning_args.freeze_trainable_layers != 0: - raise ValueError( - "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( - num_layers, finetuning_args.freeze_trainable_layers - ) - ) - - stride = num_layers // finetuning_args.freeze_trainable_layers - trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) - elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0 - trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers) - else: # fine-tuning the first n layers if num_layer_trainable < 0 - trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers)) - - hidden_modules = set() - non_hidden_modules = set() - for name, _ in model.named_parameters(): - if ".0." in name: - hidden_modules.add(name.split(".0.")[-1].split(".")[0]) - elif ".1." 
in name: # MoD starts from layer 1 - hidden_modules.add(name.split(".1.")[-1].split(".")[0]) - - if re.search(r"\.\d+\.", name) is None: - non_hidden_modules.add(name.split(".")[-2]) - - trainable_layers = [] - for module_name in finetuning_args.freeze_trainable_modules: - if module_name != "all" and module_name not in hidden_modules: - raise ValueError( - "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules)) - ) - - for idx in trainable_layer_ids: - trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) - - if finetuning_args.freeze_extra_modules: - for module_name in finetuning_args.freeze_extra_modules: - if module_name not in non_hidden_modules: - raise ValueError( - "Module {} is not found, please choose from {}".format( - module_name, ", ".join(non_hidden_modules) - ) - ) - - trainable_layers.append(module_name) - - for name, param in model.named_parameters(): - if any(trainable_layer in name for trainable_layer in trainable_layers): - if cast_trainable_params_to_fp32: - param.data = param.data.to(torch.float32) - else: - param.requires_grad_(False) - - if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model - model.vision_tower.requires_grad_(False) - - logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))) - - if finetuning_args.finetuning_type == "lora": - logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) - adapter_to_resume = None - - if model_args.adapter_name_or_path is not None: - is_mergeable = True - if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable - assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." - is_mergeable = False - - if is_deepspeed_zero3_enabled(): - assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." - is_mergeable = False - - if model_args.use_unsloth: - assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." 
- is_mergeable = False - - if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): - adapter_to_merge = model_args.adapter_name_or_path[:-1] - adapter_to_resume = model_args.adapter_name_or_path[-1] - else: - adapter_to_merge = model_args.adapter_name_or_path - - for adapter in adapter_to_merge: - model: "LoraModel" = PeftModel.from_pretrained( - model, adapter, offload_folder=model_args.offload_folder - ) - model = model.merge_and_unload() - - if len(adapter_to_merge) > 0: - logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) - - if adapter_to_resume is not None: # resume lora training - if model_args.use_unsloth: - model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) - else: - model = PeftModel.from_pretrained( - model, - adapter_to_resume, - is_trainable=is_trainable, - offload_folder=model_args.offload_folder, - ) - - if is_trainable and adapter_to_resume is None: # create new lora weights while training - if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": - target_modules = find_all_linear_modules(model) - else: - target_modules = finetuning_args.lora_target - - if finetuning_args.use_llama_pro: - target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) - - if ( - finetuning_args.use_dora - and getattr(model, "quantization_method", None) is not None - and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES - ): - raise ValueError("DoRA is not compatible with PTQ-quantized models.") - - if model_args.resize_vocab and finetuning_args.additional_target is None: - input_embeddings = model.get_input_embeddings() - output_embeddings = model.get_output_embeddings() - module_names = set() - for name, module in model.named_modules(): - if module in [input_embeddings, output_embeddings]: - module_names.add(name.split(".")[-1]) - - finetuning_args.additional_target = module_names - logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) - - peft_kwargs = { - "r": finetuning_args.lora_rank, - "target_modules": target_modules, - "lora_alpha": finetuning_args.lora_alpha, - "lora_dropout": finetuning_args.lora_dropout, - "use_rslora": finetuning_args.use_rslora, - "modules_to_save": finetuning_args.additional_target, - } - - if model_args.use_unsloth: - model = get_unsloth_peft_model(model, model_args, peft_kwargs) - else: - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - use_dora=finetuning_args.use_dora, - **peft_kwargs, - ) - model = get_peft_model(model, lora_config) - - if cast_trainable_params_to_fp32: - for param in filter(lambda p: p.requires_grad, model.parameters()): - param.data = param.data.to(torch.float32) - - if model_args.adapter_name_or_path is not None: - logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + else: + raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type)) return model diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 49b347d5..69cccd93 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer @@ -6,11 +20,11 @@ from trl import AutoModelForCausalLMWithValueHead from ..extras.logging import get_logger from ..extras.misc import count_parameters, try_download_model_from_ms from .adapter import init_adapter +from .model_utils.misc import register_autoclass +from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model +from .model_utils.unsloth import load_unsloth_pretrained_model +from .model_utils.valuehead import load_valuehead_params from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model -from .utils.misc import register_autoclass -from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model -from .utils.unsloth import load_unsloth_pretrained_model -from .utils.valuehead import load_valuehead_params if TYPE_CHECKING: @@ -131,6 +145,8 @@ def load_model( model = load_mod_pretrained_model(**init_kwargs) elif model_args.visual_inputs: model = AutoModelForVision2Seq.from_pretrained(**init_kwargs) + elif model_args.train_from_scratch: + model = AutoModelForCausalLM.from_config(config) else: model = AutoModelForCausalLM.from_pretrained(**init_kwargs) diff --git a/src/llamafactory/model/utils/__init__.py b/src/llamafactory/model/model_utils/__init__.py similarity index 100% rename from src/llamafactory/model/utils/__init__.py rename to src/llamafactory/model/model_utils/__init__.py diff --git a/src/llamafactory/model/utils/attention.py b/src/llamafactory/model/model_utils/attention.py similarity index 70% rename from src/llamafactory/model/utils/attention.py rename to src/llamafactory/model/model_utils/attention.py index b52ddc86..8ff3807b 100644 --- a/src/llamafactory/model/utils/attention.py +++ b/src/llamafactory/model/model_utils/attention.py @@ -1,7 +1,22 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
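As an aside to the `train_from_scratch` branch added to `load_model` above: a minimal sketch of how `from_config` differs from `from_pretrained` in transformers. The model id `"gpt2"` is a placeholder for illustration, not something this patch touches.

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")        # reuse only the architecture definition
model = AutoModelForCausalLM.from_config(config)   # weights are randomly initialized
# whereas AutoModelForCausalLM.from_pretrained("gpt2") would load the released weights
```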
+ from typing import TYPE_CHECKING +from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available + from ...extras.logging import get_logger -from ...extras.packages import is_flash_attn2_available, is_sdpa_available if TYPE_CHECKING: @@ -21,13 +36,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model requested_attn_implementation = "eager" elif model_args.flash_attn == "sdpa": - if not is_sdpa_available(): + if not is_torch_sdpa_available(): logger.warning("torch>=2.1.1 is required for SDPA attention.") return requested_attn_implementation = "sdpa" elif model_args.flash_attn == "fa2": - if not is_flash_attn2_available(): + if not is_flash_attn_2_available(): logger.warning("FlashAttention-2 is not installed.") return diff --git a/src/llamafactory/model/utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py similarity index 82% rename from src/llamafactory/model/utils/checkpointing.py rename to src/llamafactory/model/model_utils/checkpointing.py index e0657be8..f5314125 100644 --- a/src/llamafactory/model/utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -1,3 +1,21 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers and PEFT library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py +# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect from functools import partial from types import MethodType @@ -68,7 +86,6 @@ def prepare_model_for_training( (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) add the upcasting of the lm_head in fp32 - Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 """ if model_args.upcast_layernorm: logger.info("Upcasting layernorm weights in float32.") diff --git a/src/llamafactory/model/utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py similarity index 76% rename from src/llamafactory/model/utils/embedding.py rename to src/llamafactory/model/model_utils/embedding.py index 357c9cc0..3ff79828 100644 --- a/src/llamafactory/model/utils/embedding.py +++ b/src/llamafactory/model/model_utils/embedding.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
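For context on the availability checks switched above in `configure_attn_implementation`: a hedged sketch of how a resolved attention backend is typically passed to transformers. The fallback order and the `"gpt2"` model id are illustrative assumptions, not the patch's own logic.

```python
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

attn_implementation = "eager"                      # always-available fallback
if is_flash_attn_2_available():
    attn_implementation = "flash_attention_2"
elif is_torch_sdpa_available():
    attn_implementation = "sdpa"                   # requires torch>=2.1.1

model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation=attn_implementation)
```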
+ import math from contextlib import nullcontext from typing import TYPE_CHECKING @@ -15,7 +29,7 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None: +def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: embedding_dim = embed_weight.size(1) avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) diff --git a/src/llamafactory/model/utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py similarity index 91% rename from src/llamafactory/model/utils/longlora.py rename to src/llamafactory/model/model_utils/longlora.py index c8dc52f5..af30bd50 100644 --- a/src/llamafactory/model/utils/longlora.py +++ b/src/llamafactory/model/model_utils/longlora.py @@ -1,3 +1,22 @@ +# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team. +# +# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +# This code is also inspired by the original LongLoRA implementation. +# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
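A runnable sketch of the noisy-mean initialization whose opening lines appear in the `_noisy_mean_initialization` hunk above. The noise scale `1/sqrt(dim)` and the toy tensors are assumptions for illustration only.

```python
import math
import torch

def noisy_mean_init(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
    noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight  # mean of old rows plus small noise

embeddings = torch.randn(102, 16)        # 100 old token embeddings plus 2 freshly appended rows
noisy_mean_init(embeddings, num_new_tokens=2)
```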
+ import math from typing import TYPE_CHECKING, Optional, Tuple @@ -96,7 +115,8 @@ def llama_attention_forward( ( attn_output[:, :, : self.num_heads // 2], attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), - ) + ), + dim=2, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -181,11 +201,9 @@ def llama_flash_attention_2_forward( query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) if attention_mask is not None: attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1) - else: - groupsz = q_len attn_output: torch.Tensor = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate + query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate ) if getattr(self.config, "group_size_ratio", None) and self.training: # shift back @@ -194,7 +212,8 @@ def llama_flash_attention_2_forward( ( attn_output[:, :, : self.num_heads // 2], attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), - ) + ), + dim=2, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() @@ -293,7 +312,8 @@ def llama_sdpa_attention_forward( ( attn_output[:, :, : self.num_heads // 2], attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), - ) + ), + dim=2, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -303,7 +323,7 @@ def llama_sdpa_attention_forward( def _apply_llama_patch() -> None: - require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2") + require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2") LlamaAttention.forward = llama_attention_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward diff --git a/src/llamafactory/model/utils/misc.py b/src/llamafactory/model/model_utils/misc.py similarity index 64% rename from src/llamafactory/model/utils/misc.py rename to src/llamafactory/model/model_utils/misc.py index eca68866..a2812228 100644 --- a/src/llamafactory/model/utils/misc.py +++ b/src/llamafactory/model/model_utils/misc.py @@ -1,9 +1,20 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, List -import torch - from ...extras.logging import get_logger -from .quantization import QuantizationMethod if TYPE_CHECKING: @@ -13,29 +24,28 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: +def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: r""" Finds all available modules to apply lora or galore. 
""" - quantization_method = getattr(model, "quantization_method", None) - if quantization_method is None: - linear_cls = torch.nn.Linear - elif quantization_method == QuantizationMethod.BITS_AND_BYTES: - import bitsandbytes as bnb + forbidden_modules = {"lm_head"} - linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt - else: - raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method)) - - output_layer_names = ["lm_head"] if model.config.model_type == "chatglm": - output_layer_names.append("output_layer") + forbidden_modules.add("output_layer") elif model.config.model_type == "internlm2": - output_layer_names.append("output") + forbidden_modules.add("output") + elif model.config.model_type in ["llava", "paligemma"]: + forbidden_modules.add("multi_modal_projector") + + if freeze_vision_tower: + forbidden_modules.add("vision_tower") module_names = set() for name, module in model.named_modules(): - if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names): + if any(forbidden_module in name for forbidden_module in forbidden_modules): + continue + + if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: module_names.add(name.split(".")[-1]) logger.info("Found linear modules: {}".format(",".join(module_names))) diff --git a/src/llamafactory/model/utils/mod.py b/src/llamafactory/model/model_utils/mod.py similarity index 58% rename from src/llamafactory/model/utils/mod.py rename to src/llamafactory/model/model_utils/mod.py index 5708a1a8..ec73af00 100644 --- a/src/llamafactory/model/utils/mod.py +++ b/src/llamafactory/model/model_utils/mod.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING from ...extras.constants import MOD_SUPPORTED_MODELS diff --git a/src/llamafactory/model/utils/moe.py b/src/llamafactory/model/model_utils/moe.py similarity index 66% rename from src/llamafactory/model/utils/moe.py rename to src/llamafactory/model/model_utils/moe.py index e554e45a..5c7473aa 100644 --- a/src/llamafactory/model/utils/moe.py +++ b/src/llamafactory/model/model_utils/moe.py @@ -1,5 +1,20 @@ -from typing import TYPE_CHECKING +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING, Sequence + +import torch from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils.versions import require_version @@ -10,6 +25,13 @@ if TYPE_CHECKING: from ...hparams import ModelArguments +def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + + set_z3_leaf_modules(model, leaf_modules) + + def add_z3_leaf_module(model: "PreTrainedModel") -> None: r""" Sets module as a leaf module to skip partitioning in deepspeed zero3. @@ -17,33 +39,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: if not is_deepspeed_zero3_enabled(): return - require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") - from deepspeed.utils import set_z3_leaf_modules # type: ignore - if getattr(model.config, "model_type", None) == "dbrx": from transformers.models.dbrx.modeling_dbrx import DbrxFFN - set_z3_leaf_modules(model, [DbrxFFN]) + _set_z3_leaf_modules(model, [DbrxFFN]) if getattr(model.config, "model_type", None) == "jamba": from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock - set_z3_leaf_modules(model, [JambaSparseMoeBlock]) + _set_z3_leaf_modules(model, [JambaSparseMoeBlock]) if getattr(model.config, "model_type", None) == "jetmoe": from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE - set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) + _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) if getattr(model.config, "model_type", None) == "mixtral": from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + _set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) if getattr(model.config, "model_type", None) == "qwen2moe": from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock - set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: diff --git a/src/llamafactory/model/utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py similarity index 89% rename from src/llamafactory/model/utils/quantization.py rename to src/llamafactory/model/model_utils/quantization.py index 161ad5aa..0a0fca34 100644 --- a/src/llamafactory/model/utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Optimum library. +# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os import random from enum import Enum, unique @@ -35,11 +52,12 @@ class QuantizationMethod(str, Enum): AWQ = "awq" AQLM = "aqlm" QUANTO = "quanto" + EETQ = "eetq" + HQQ = "hqq" def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: r""" - Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 """ if os.path.isfile(model_args.export_quantization_dataset): diff --git a/src/llamafactory/model/utils/rope.py b/src/llamafactory/model/model_utils/rope.py similarity index 67% rename from src/llamafactory/model/utils/rope.py rename to src/llamafactory/model/model_utils/rope.py index 93ab8929..88303c4d 100644 --- a/src/llamafactory/model/utils/rope.py +++ b/src/llamafactory/model/model_utils/rope.py @@ -1,3 +1,21 @@ +# Copyright 2024 LMSYS and the LlamaFactory team. +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# This code is inspired by the LMSYS's FastChat library. +# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math from typing import TYPE_CHECKING diff --git a/src/llamafactory/model/utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py similarity index 83% rename from src/llamafactory/model/utils/unsloth.py rename to src/llamafactory/model/model_utils/unsloth.py index 8a16409d..9cfaec61 100644 --- a/src/llamafactory/model/utils/unsloth.py +++ b/src/llamafactory/model/model_utils/unsloth.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, Optional from ...extras.logging import get_logger diff --git a/src/llamafactory/model/utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py similarity index 70% rename from src/llamafactory/model/utils/valuehead.py rename to src/llamafactory/model/model_utils/valuehead.py index d813729e..9ab3d45a 100644 --- a/src/llamafactory/model/utils/valuehead.py +++ b/src/llamafactory/model/model_utils/valuehead.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict import torch @@ -23,6 +37,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. """ kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} + err_text = "" try: from safetensors import safe_open @@ -31,16 +46,16 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> with safe_open(vhead_file, framework="pt", device="cpu") as f: return {key: f.get_tensor(key) for key in f.keys()} except Exception as err: - logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err))) + err_text = str(err) try: vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) return torch.load(vhead_file, map_location="cpu") except Exception as err: - logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err))) + err_text = str(err) - logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id)) - logger.info("Ignore these messages if you are not resuming the training of a value head model.") + logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text)) + logger.info("Ignore the above message if you are not resuming the training of a value head model.") return None diff --git a/src/llamafactory/model/utils/visual.py b/src/llamafactory/model/model_utils/visual.py similarity index 82% rename from src/llamafactory/model/utils/visual.py rename to src/llamafactory/model/model_utils/visual.py index c8260b7f..700bf470 100644 --- a/src/llamafactory/model/utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Tuple import torch diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 1a8ce607..053516e4 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from types import MethodType from typing import TYPE_CHECKING, Any, Dict @@ -10,15 +24,15 @@ from transformers.modeling_utils import is_fsdp_enabled from ..extras.logging import get_logger from ..extras.misc import infer_optim_dtype -from .utils.attention import configure_attn_implementation, print_attn_implementation -from .utils.checkpointing import prepare_model_for_training -from .utils.embedding import resize_embedding_layer -from .utils.longlora import configure_longlora -from .utils.moe import add_z3_leaf_module, configure_moe -from .utils.quantization import configure_quantization -from .utils.rope import configure_rope -from .utils.valuehead import prepare_valuehead_model -from .utils.visual import autocast_projector_dtype, configure_visual_model +from .model_utils.attention import configure_attn_implementation, print_attn_implementation +from .model_utils.checkpointing import prepare_model_for_training +from .model_utils.embedding import resize_embedding_layer +from .model_utils.longlora import configure_longlora +from .model_utils.moe import add_z3_leaf_module, configure_moe +from .model_utils.quantization import configure_quantization +from .model_utils.rope import configure_rope +from .model_utils.valuehead import prepare_valuehead_model +from .model_utils.visual import autocast_projector_dtype, configure_visual_model if TYPE_CHECKING: @@ -44,7 +58,10 @@ def patch_config( is_trainable: bool, ) -> None: if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 - model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + if model_args.infer_dtype == "auto": + model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + else: + model_args.compute_dtype = getattr(torch, model_args.infer_dtype) if is_torch_npu_available(): use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"] @@ -79,7 +96,7 @@ def patch_config( if "device_map" not in init_kwargs and model_args.device_map: init_kwargs["device_map"] = model_args.device_map - if init_kwargs["device_map"] == "auto": + if init_kwargs.get("device_map", None) == "auto": init_kwargs["offload_folder"] = model_args.offload_folder diff --git a/src/llamafactory/train/dpo/__init__.py b/src/llamafactory/train/dpo/__init__.py index 43fe9420..9ce0d089 100644 --- a/src/llamafactory/train/dpo/__init__.py +++ b/src/llamafactory/train/dpo/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from .workflow import run_dpo diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index f3c2443c..9928d0bc 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -1,3 +1,22 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType @@ -7,10 +26,10 @@ import torch import torch.nn.functional as F from transformers import Trainer from trl import DPOTrainer -from trl.trainer.utils import disable_dropout_in_model +from trl.trainer import disable_dropout_in_model from ...extras.constants import IGNORE_INDEX -from ..utils import create_custom_optimzer, create_custom_scheduler +from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler, get_batch_logps if TYPE_CHECKING: @@ -61,6 +80,8 @@ class CustomDPOTrainer(DPOTrainer): if not hasattr(self, "accelerator"): raise AttributeError("Please update `transformers`.") + warnings.simplefilter("ignore") # remove gc warnings on ref model + if ref_model is not None: if self.is_deepspeed_enabled: if not ( @@ -69,6 +90,10 @@ class CustomDPOTrainer(DPOTrainer): self.ref_model = self._prepare_deepspeed(self.ref_model) else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + self.ref_model.eval() + + if finetuning_args.pissa_convert: + self.save_model(os.path.join(self.args.output_dir, "pissa_init")) if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -88,22 +113,13 @@ class CustomDPOTrainer(DPOTrainer): def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir + if self.finetuning_args.pissa_convert: + convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args) + if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) - def sft_loss(self, batch: Dict[str, "torch.Tensor"], chosen_logits: "torch.FloatTensor") -> "torch.Tensor": - r""" - Computes supervised cross-entropy loss of given labels under the given logits. - - Returns: - A tensor of shape (batch_size,) containing the cross-entropy loss of each samples. 
- """ - batch_size = batch["input_ids"].size(0) // 2 - chosen_labels, _ = batch["labels"].split(batch_size, dim=0) - chosen_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True) - return -chosen_logps - def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": r""" Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model. @@ -155,9 +171,9 @@ class CustomDPOTrainer(DPOTrainer): def concatenated_forward( self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] - ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: r""" - Computes the sum log probabilities of the labels under the given logits if loss_type != IPO. + Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO. Otherwise the average log probabilities. """ @@ -166,20 +182,18 @@ class CustomDPOTrainer(DPOTrainer): all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) - all_logps = self.get_batch_logps( - logits=all_logits, - labels=batch["labels"], - average_log_prob=(self.loss_type in ["ipo", "orpo", "simpo"]), - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) + all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) + if self.loss_type in ["ipo", "orpo", "simpo"]: + all_logps = all_logps / valid_length + batch_size = batch["input_ids"].size(0) // 2 chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) - return chosen_logps, rejected_logps, chosen_logits, rejected_logits + chosen_length, _ = valid_length.split(batch_size, dim=0) + return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length def compute_reference_log_probs( - self, batch: Dict[str, "torch.Tensor"] + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: r""" Computes log probabilities of the reference model. 
@@ -188,19 +202,14 @@ class CustomDPOTrainer(DPOTrainer): return None, None if self.ref_model is None: - ref_model = self.model - ref_context = self.accelerator.unwrap_model(self.model).disable_adapter() + ref_model = model + ref_context = self.accelerator.unwrap_model(model).disable_adapter() else: ref_model = self.ref_model ref_context = nullcontext() with torch.no_grad(), ref_context: - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - ) = self.concatenated_forward(ref_model, batch) + reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch) return reference_chosen_logps, reference_rejected_logps @@ -219,16 +228,17 @@ class CustomDPOTrainer(DPOTrainer): policy_rejected_logps, policy_chosen_logits, policy_rejected_logits, + policy_chosen_logps_avg, ) = self.concatenated_forward(model, batch) - reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(batch) + reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch) losses, chosen_rewards, rejected_rewards = self.compute_preference_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, ) - sft_loss = self.sft_loss(batch, policy_chosen_logits) # compute chosen_logps with masks + sft_loss = -policy_chosen_logps_avg if self.ftx_gamma > 1e-6: losses += self.ftx_gamma * sft_loss diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py index 61a3e2f0..431b5285 100644 --- a/src/llamafactory/train/dpo/workflow.py +++ b/src/llamafactory/train/dpo/workflow.py @@ -1,4 +1,19 @@ -# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING, List, Optional @@ -7,7 +22,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.ploting import plot_loss from ...hparams import ModelArguments from ...model import load_model, load_tokenizer -from ..utils import create_modelcard_and_push, create_ref_model +from ..trainer_utils import create_modelcard_and_push, create_ref_model from .trainer import CustomDPOTrainer diff --git a/src/llamafactory/train/kto/__init__.py b/src/llamafactory/train/kto/__init__.py index 34c7905a..a1900368 100644 --- a/src/llamafactory/train/kto/__init__.py +++ b/src/llamafactory/train/kto/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
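Note (illustrative, not part of the patch): with `sft_loss = -policy_chosen_logps_avg`, the auxiliary supervised term is simply the negative length-normalized chosen log probability, scaled by `ftx_gamma`. A small sketch of that combination, using assumed tensor names:

```python
import torch


def combine_losses(preference_losses: torch.Tensor,
                   policy_chosen_logps_avg: torch.Tensor,
                   ftx_gamma: float) -> torch.Tensor:
    """Add the optional SFT regularizer to the per-example preference losses."""
    losses = preference_losses
    if ftx_gamma > 1e-6:
        sft_loss = -policy_chosen_logps_avg  # NLL of the chosen responses
        losses = losses + ftx_gamma * sft_loss
    return losses.mean()
```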
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_kto diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 096fd935..91d68975 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -1,18 +1,37 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings from collections import defaultdict from contextlib import nullcontext from types import MethodType -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union import torch from transformers import Trainer from trl import KTOTrainer -from trl.trainer.utils import disable_dropout_in_model +from trl.trainer import disable_dropout_in_model from ...extras.constants import IGNORE_INDEX -from ..utils import create_custom_optimzer, create_custom_scheduler +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps if TYPE_CHECKING: + import torch.utils.data from transformers import PreTrainedModel, ProcessorMixin from ...hparams import FinetuningArguments @@ -59,6 +78,8 @@ class CustomKTOTrainer(KTOTrainer): if not hasattr(self, "accelerator"): raise AttributeError("Please update `transformers`.") + warnings.simplefilter("ignore") # remove gc warnings on ref model + if ref_model is not None: if self.is_deepspeed_enabled: if not ( @@ -67,6 +88,7 @@ class CustomKTOTrainer(KTOTrainer): self.ref_model = self._prepare_deepspeed(self.ref_model) else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + self.ref_model.eval() if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -84,73 +106,74 @@ class CustomKTOTrainer(KTOTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + r""" + Replaces the sequential sampler of KTO Trainer created by trl with the random sampler. 
+ """ + return Trainer._get_train_sampler(self) + def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) - def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor": - r""" - Computes supervised cross-entropy loss of given labels under the given logits. - - Returns: - A tensor of shape (batch_size,) containing the cross-entropy loss of each samples. - """ - all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True) - return -all_logps - def forward( - self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] - ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: - with torch.no_grad(): - kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]} - if "pixel_values" in batch: - kl_model_inputs["pixel_values"] = batch["pixel_values"] - - if "kl_token_type_ids" in batch: - kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"] - - kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) - - model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]} + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" + ) -> Tuple["torch.Tensor", "torch.Tensor"]: + r""" + Runs forward pass and computes the log probabilities. + """ + batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error + model_inputs = { + "input_ids": batch["{}input_ids".format(prefix)], + "attention_mask": batch["{}attention_mask".format(prefix)], + } if "pixel_values" in batch: model_inputs["pixel_values"] = batch["pixel_values"] - if "token_type_ids" in batch: - model_inputs["token_type_ids"] = batch["token_type_ids"] + if "{}token_type_ids".format(prefix) in batch: + model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] - target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) + logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) - target_logps = self.get_batch_logps( - logits=target_logits, - labels=batch["labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) + logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)]) + return logps, logps / valid_length - kl_logps = self.get_batch_logps( - logits=kl_logits, - labels=batch["kl_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) + def concatenated_forward( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + target_logps, target_logps_avg = self.forward(model, batch) + with torch.no_grad(): + kl_logps, _ = self.forward(model, batch, prefix="kl_") if len(target_logps) != len(batch["kto_tags"]): raise ValueError("Mismatched shape of inputs and labels.") - chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]] - rejected_idx = [i for i 
in range(len(target_logps)) if not batch["kto_tags"][i]] + chosen_logps = target_logps[batch["kto_tags"]] + rejected_logps = target_logps[~batch["kto_tags"]] + chosen_logps_avg = target_logps_avg[batch["kto_tags"]] + return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg - chosen_logps = target_logps[chosen_idx, ...] - rejected_logps = target_logps[rejected_idx, ...] + def compute_reference_log_probs( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r""" + Computes log probabilities of the reference model. + """ + if self.ref_model is None: + ref_model = model + ref_context = self.accelerator.unwrap_model(model).disable_adapter() + else: + ref_model = self.ref_model + ref_context = nullcontext() - chosen_logits = target_logits[chosen_idx, ...] - rejected_logits = target_logits[rejected_idx, ...] + with torch.no_grad(), ref_context: + reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward( + ref_model, batch + ) - return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps + return reference_chosen_logps, reference_rejected_logps, reference_kl_logps def get_batch_loss_metrics( self, @@ -161,31 +184,12 @@ class CustomKTOTrainer(KTOTrainer): Computes the DPO loss and other metrics for the given batch of inputs for train or test. """ metrics = {} - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - _, - policy_kl_logps, - ) = self.forward(model, batch) - - with torch.no_grad(): - if self.ref_model is None: - ref_model = self.model - ref_context = self.accelerator.unwrap_model(self.model).disable_adapter() - else: - ref_model = self.ref_model - ref_context = nullcontext() - - with ref_context: - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - reference_kl_logps, - ) = self.forward(ref_model, batch) - + policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = ( + self.concatenated_forward(model, batch) + ) + reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs( + model, batch + ) losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( policy_chosen_logps, policy_rejected_logps, @@ -197,8 +201,8 @@ class CustomKTOTrainer(KTOTrainer): losses = losses.nanmean() if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale - sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]]) - losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"]) + sft_loss = -policy_chosen_logps_avg + losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"]) num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index 26dc770c..8182a184 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
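Note (illustrative, not part of the patch): the KTO trainer now splits chosen and rejected rows with a boolean `kto_tags` mask instead of Python index lists. A tiny, self-contained demonstration of that indexing:

```python
import torch

# Toy values standing in for per-sequence log probabilities and their tags.
target_logps = torch.tensor([-1.2, -0.5, -2.0, -0.8])
kto_tags = torch.tensor([True, False, True, False])

chosen_logps = target_logps[kto_tags]     # -> tensor([-1.2000, -2.0000])
rejected_logps = target_logps[~kto_tags]  # -> tensor([-0.5000, -0.8000])
```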
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, List, Optional from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset @@ -5,7 +22,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.ploting import plot_loss from ...hparams import ModelArguments from ...model import load_model, load_tokenizer -from ..utils import create_modelcard_and_push, create_ref_model +from ..trainer_utils import create_modelcard_and_push, create_ref_model from .trainer import CustomKTOTrainer diff --git a/src/llamafactory/train/ppo/__init__.py b/src/llamafactory/train/ppo/__init__.py index d17336d5..161f6f5d 100644 --- a/src/llamafactory/train/ppo/__init__.py +++ b/src/llamafactory/train/ppo/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_ppo diff --git a/src/llamafactory/train/ppo/utils.py b/src/llamafactory/train/ppo/ppo_utils.py similarity index 52% rename from src/llamafactory/train/ppo/utils.py rename to src/llamafactory/train/ppo/ppo_utils.py index e6bdb89c..05c40946 100644 --- a/src/llamafactory/train/ppo/utils.py +++ b/src/llamafactory/train/ppo/ppo_utils.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from contextlib import nullcontext from typing import TYPE_CHECKING, Dict, List, Literal, Optional @@ -8,15 +22,19 @@ from transformers.integrations import is_deepspeed_zero3_enabled from ...extras.packages import is_requests_available -if TYPE_CHECKING: - from transformers import PreTrainedModel - from trl import AutoModelForCausalLMWithValueHead - if is_requests_available(): import requests +if TYPE_CHECKING: + from transformers import PreTrainedModel + from trl import AutoModelForCausalLMWithValueHead + + def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]: + r""" + Gets reward scores from the API server. 
+ """ headers = {"Content-Type": "application/json"} payload = {"model": "model", "messages": messages} response = requests.post(server_url, json=payload, headers=headers) @@ -25,25 +43,33 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch. def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: + r""" + Replaces the default/reward modules in the model. The model is already unwrapped. + """ + v_head_layer = model.v_head.summary if is_deepspeed_zero3_enabled(): import deepspeed # type: ignore - params = [model.v_head.summary.weight, model.v_head.summary.bias] + params = [v_head_layer.weight, v_head_layer.bias] context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) else: context_maybe_zero3 = nullcontext() + model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active with context_maybe_zero3: if target == "reward": # save default head temporarily - setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone()) - setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone()) + setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone()) + setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone()) - model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active - model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone() - model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone() + device = v_head_layer.weight.device + v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device) + v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device) def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: + r""" + Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered). + """ layer_norm_params = {} for name, param in model.named_parameters(): if param.data.dtype == torch.float32: @@ -54,6 +80,9 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None: + r""" + Restores the layernorm parameters in the model. The model is already unwrapped (and gathered). + """ for name, param in model.named_parameters(): if name in layernorm_params: param.data = layernorm_params[name] diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 985664b7..df4a37be 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -1,10 +1,29 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + import math import os import sys +import warnings from types import MethodType -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch +from accelerate.utils import DistributedDataParallelKwargs from tqdm import tqdm from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState from transformers.optimization import get_scheduler @@ -13,12 +32,13 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME from trl import PPOConfig, PPOTrainer from trl.core import PPODecorators, logprobs_from_logits +from trl.models.utils import unwrap_model_for_generation from ...extras.callbacks import FixValueHeadModelCallback, LogCallback from ...extras.logging import get_logger from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor -from ..utils import create_custom_optimzer, create_custom_scheduler -from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler +from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm if TYPE_CHECKING: @@ -78,6 +98,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer): project_kwargs={"logging_dir": training_args.logging_dir}, ) + # Add deepspeed config + ppo_config.accelerator_kwargs["kwargs_handlers"] = [ + DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters) + ] + if training_args.deepspeed_plugin is not None: + ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin + # Create optimizer and scheduler if training_args.max_steps > 0: num_training_steps = training_args.max_steps @@ -114,15 +141,20 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.state = TrainerState() self.control = TrainerControl() - self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( - self.accelerator.state, "deepspeed_plugin" - ) + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None self.log_callback, self.save_callback = callbacks[0], callbacks[1] assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback) if self.args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm" + + self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype) + warnings.simplefilter("ignore") # remove gc warnings on ref model + if finetuning_args.reward_model_type == "full": if self.is_deepspeed_enabled: if not ( @@ -183,7 +215,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer): logger.info(" Total training steps = {}".format(max_steps)) logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0])) - unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) dataiter = iter(self.dataloader) loss_meter = AverageMeter() reward_meter = 
AverageMeter() @@ -196,29 +227,21 @@ class CustomPPOTrainer(PPOTrainer, Trainer): dataiter = iter(self.dataloader) batch = next(dataiter) - # Cast to inference mode - unwrapped_model.gradient_checkpointing_disable() - unwrapped_model.config.use_cache = True - self.model.eval() - # Get inputs + self.model.eval() self.tokenizer.padding_side = "right" # change padding side queries, responses, rewards = [], [], [] for idx in range(0, self.config.batch_size, self.config.mini_batch_size): mini_batch_queries, mini_batch_responses = self.get_inputs( batch[idx : idx + self.config.mini_batch_size] ) - mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model) + mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses) queries.extend(mini_batch_queries) responses.extend(mini_batch_responses) rewards.extend(mini_batch_rewards) - # Cast to training mode - unwrapped_model.gradient_checkpointing_enable() - unwrapped_model.config.use_cache = False - self.model.train() - # Run PPO step + self.model.train() stats = self.step(queries, responses, rewards) self.tokenizer.padding_side = "left" # restore padding side loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards)) @@ -303,32 +326,26 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ) return lr_scheduler - def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: - super()._save(output_dir, state_dict) - if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir - getattr(self.processor, "image_processor").save_pretrained(output_dir) - @torch.no_grad() - def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]: r""" Generates model's responses given queries. 
""" - if self.model_args.upcast_layernorm: - layernorm_params = dump_layernorm(self.model) - if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1 start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item() for k, v in batch.items(): batch[k] = v[:, start_index:] - unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) - generate_output: torch.Tensor = unwrapped_model.generate( - generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch - ) + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6 + if self.model_args.upcast_layernorm: + layernorm_params = dump_layernorm(unwrapped_model) - if self.model_args.upcast_layernorm: - restore_layernorm(self.model, layernorm_params) + generate_output: torch.Tensor = unwrapped_model.generate( + generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch + ) + if self.model_args.upcast_layernorm: + restore_layernorm(unwrapped_model, layernorm_params) query = batch["input_ids"].detach().cpu() response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu() @@ -350,10 +367,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): @torch.no_grad() def get_rewards( self, - queries: List[torch.Tensor], - responses: List[torch.Tensor], - unwrapped_model: "AutoModelForCausalLMWithValueHead", - ) -> List[torch.Tensor]: + queries: List["torch.Tensor"], + responses: List["torch.Tensor"], + ) -> List["torch.Tensor"]: r""" Computes scores using given reward model. @@ -364,18 +380,22 @@ class CustomPPOTrainer(PPOTrainer, Trainer): messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) return get_rewards_from_server(self.reward_model, messages) + batch = self.prepare_model_inputs(queries, responses) + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + if self.finetuning_args.reward_model_type == "lora": replace_model(unwrapped_model, target="reward") reward_model = self.model else: reward_model = self.reward_model - batch = self.prepare_model_inputs(queries, responses) - - with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 + with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16 _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False) - if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture + if self.finetuning_args.reward_model_type == "lora": + replace_model(unwrapped_model, target="default") + + if self.is_chatglm_model: # assume same architecture values = torch.transpose(values, 0, 1) rewards = [] @@ -384,21 +404,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer): end_index = end_indexes[-1].item() if len(end_indexes) else 0 rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type - if self.finetuning_args.reward_model_type == "lora": - replace_model(unwrapped_model, target="default") - return rewards @PPODecorators.empty_device_cache() def batched_forward_pass( self, model: "AutoModelForCausalLMWithValueHead", - queries: torch.Tensor, - responses: torch.Tensor, - model_inputs: dict, + queries: "torch.Tensor", + responses: "torch.Tensor", + model_inputs: Dict[str, Any], return_logits: bool = False, - response_masks: Optional[torch.Tensor] = 
None, - ): + response_masks: Optional["torch.Tensor"] = None, + ) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]: r""" Calculates model outputs in multiple batches. @@ -420,11 +437,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer): input_ids = input_kwargs["input_ids"] attention_mask = input_kwargs["attention_mask"] - with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 + with self.amp_context: # support bf16 logits, _, values = model(**input_kwargs) - unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) - if getattr(unwrapped_model.config, "model_type", None) == "chatglm": + if self.is_chatglm_model: values = torch.transpose(values, 0, 1) logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) @@ -467,14 +483,28 @@ class CustomPPOTrainer(PPOTrainer, Trainer): Subclass and override to inject custom behavior. """ - if self.args.should_save: + if output_dir is None: + output_dir = self.args.output_dir + + if self.is_fsdp_enabled or self.is_deepspeed_enabled: try: - self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model)) + state_dict = self.accelerator.get_state_dict(self.model) # must be called at all ranks + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) except ValueError: logger.warning( " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead," " use zero_to_fp32.py to recover weights" ) - self._save(output_dir, state_dict={}) - remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) + if self.args.should_save: + self._save(output_dir, state_dict={}) + # remove the dummy state_dict + remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) self.model.save_checkpoint(output_dir) + + elif self.args.should_save: + self._save(output_dir) + + if self.processor is not None and self.args.should_save: + output_dir = output_dir if output_dir is not None else self.args.output_dir + getattr(self.processor, "image_processor").save_pretrained(output_dir) diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py index c4e05e57..4f4d2820 100644 --- a/src/llamafactory/train/ppo/workflow.py +++ b/src/llamafactory/train/ppo/workflow.py @@ -1,4 +1,19 @@ -# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's TRL library. +# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
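Note (illustrative, not part of the patch): the reworked `save_model` gathers the state dict on every rank before only the saving process writes it, which is required when DeepSpeed ZeRO-3 or FSDP shards the weights. A simplified sketch of that pattern with generic names (the trainer itself routes the result through `_save` and `should_save`):

```python
from accelerate import Accelerator
from transformers import PreTrainedModel


def save_sharded_model(accelerator: Accelerator, model: PreTrainedModel, output_dir: str) -> None:
    """Gather sharded weights collectively, then write them from the main process only."""
    state_dict = accelerator.get_state_dict(model)  # collective call: must run on all ranks
    if accelerator.is_main_process:
        accelerator.unwrap_model(model).save_pretrained(output_dir, state_dict=state_dict)
```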
from typing import TYPE_CHECKING, List, Optional @@ -9,7 +24,7 @@ from ...extras.callbacks import FixValueHeadModelCallback from ...extras.misc import fix_valuehead_checkpoint from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer -from ..utils import create_ref_model, create_reward_model +from ..trainer_utils import create_ref_model, create_reward_model from .trainer import CustomPPOTrainer @@ -29,7 +44,7 @@ def run_ppo( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) + dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training diff --git a/src/llamafactory/train/pt/__init__.py b/src/llamafactory/train/pt/__init__.py index bdf397f6..d80e6f22 100644 --- a/src/llamafactory/train/pt/__init__.py +++ b/src/llamafactory/train/pt/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_pt diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index b7b80f88..f9e04cb5 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -1,10 +1,25 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os from types import MethodType from typing import TYPE_CHECKING, Dict, Optional from transformers import Trainer from ...extras.logging import get_logger -from ..utils import create_custom_optimzer, create_custom_scheduler +from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: @@ -28,6 +43,10 @@ class CustomTrainer(Trainer): super().__init__(**kwargs) self.finetuning_args = finetuning_args self.processor = processor + + if finetuning_args.pissa_convert: + self.save_model(os.path.join(self.args.output_dir, "pissa_init")) + if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -46,6 +65,9 @@ class CustomTrainer(Trainer): def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir + if self.finetuning_args.pissa_convert: + convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args) + if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index 9f945901..b84a0e7d 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -1,4 +1,19 @@ -# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math from typing import TYPE_CHECKING, List, Optional @@ -8,7 +23,7 @@ from transformers import DataCollatorForLanguageModeling from ...data import get_dataset, split_dataset from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer -from ..utils import create_modelcard_and_push +from ..trainer_utils import create_modelcard_and_push from .trainer import CustomTrainer diff --git a/src/llamafactory/train/rm/__init__.py b/src/llamafactory/train/rm/__init__.py index dedac35f..48278315 100644 --- a/src/llamafactory/train/rm/__init__.py +++ b/src/llamafactory/train/rm/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_rm diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py index 99dc6ab8..fb880b1c 100644 --- a/src/llamafactory/train/rm/metric.py +++ b/src/llamafactory/train/rm/metric.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Dict, Sequence, Tuple, Union import numpy as np diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index d49dd67b..7f91e5f5 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -1,3 +1,42 @@ +# Copyright 2024 the LlamaFactory team. +# +# This code is inspired by the CarperAI's trlx library. +# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2022 CarperAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ import json import os from types import MethodType @@ -7,7 +46,7 @@ import torch from transformers import Trainer from ...extras.logging import get_logger -from ..utils import create_custom_optimzer, create_custom_scheduler +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: @@ -50,8 +89,8 @@ class PairwiseTrainer(Trainer): def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) def compute_loss( @@ -79,7 +118,6 @@ class PairwiseTrainer(Trainer): chosen_scores, rejected_scores = [], [] # Compute pairwise loss. Only backprop on the different tokens before padding - # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py loss = 0 for i in range(batch_size): chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1 diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py index 621d03b7..6f24e964 100644 --- a/src/llamafactory/train/rm/workflow.py +++ b/src/llamafactory/train/rm/workflow.py @@ -1,4 +1,41 @@ -# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py +# Copyright 2024 the LlamaFactory team. +# +# This code is inspired by the CarperAI's trlx library. +# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2022 CarperAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
from typing import TYPE_CHECKING, List, Optional @@ -7,7 +44,7 @@ from ...extras.callbacks import FixValueHeadModelCallback from ...extras.misc import fix_valuehead_checkpoint from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer -from ..utils import create_modelcard_and_push +from ..trainer_utils import create_modelcard_and_push from .metric import compute_accuracy from .trainer import PairwiseTrainer diff --git a/src/llamafactory/train/sft/__init__.py b/src/llamafactory/train/sft/__init__.py index f2f84e78..475dfe5f 100644 --- a/src/llamafactory/train/sft/__init__.py +++ b/src/llamafactory/train/sft/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .workflow import run_sft diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py index d1af4c17..95bfcb69 100644 --- a/src/llamafactory/train/sft/metric.py +++ b/src/llamafactory/train/sft/metric.py @@ -1,21 +1,43 @@ +# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union import numpy as np +from transformers.utils import is_jieba_available, is_nltk_available from ...extras.constants import IGNORE_INDEX -from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available +from ...extras.packages import is_rouge_available if TYPE_CHECKING: - from transformers.tokenization_utils import PreTrainedTokenizer + from transformers import PreTrainedTokenizer + if is_jieba_available(): import jieba # type: ignore + if is_nltk_available(): from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + if is_rouge_available(): from rouge_chinese import Rouge diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index cd73bf5c..0628ea59 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -1,3 +1,20 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. 
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from types import MethodType @@ -9,10 +26,11 @@ from transformers import Seq2SeqTrainer from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from ..utils import create_custom_optimzer, create_custom_scheduler +from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: + from torch.utils.data import Dataset from transformers import ProcessorMixin from transformers.trainer import PredictionOutput @@ -33,6 +51,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): super().__init__(**kwargs) self.finetuning_args = finetuning_args self.processor = processor + + if finetuning_args.pissa_convert: + self.save_model(os.path.join(self.args.output_dir, "pissa_init")) + if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -51,8 +73,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir + if self.finetuning_args.pissa_convert: + convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args) + if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) def training_step(self, *args, **kwargs): @@ -109,7 +134,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding return padded_tensor.contiguous() # in contiguous memory - def save_predictions(self, predict_results: "PredictionOutput") -> None: + def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None: r""" Saves model predictions to `output_dir`. 
@@ -135,6 +160,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1 ) # move pad token to last + decoded_inputs = self.tokenizer.batch_decode( + dataset["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False + ) decoded_labels = self.tokenizer.batch_decode( labels, skip_special_tokens=True, clean_up_tokenization_spaces=False ) @@ -142,6 +170,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): with open(output_prediction_file, "w", encoding="utf-8") as writer: res: List[str] = [] - for label, pred in zip(decoded_labels, decoded_preds): - res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds): + res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False)) writer.write("\n".join(res)) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index d9d7c8e9..885bc7ac 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -1,4 +1,19 @@ -# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING, List, Optional @@ -9,7 +24,7 @@ from ...extras.constants import IGNORE_INDEX from ...extras.misc import get_logits_processor from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer -from ..utils import create_modelcard_and_push +from ..trainer_utils import create_modelcard_and_push from .metric import ComputeMetrics from .trainer import CustomSeq2SeqTrainer @@ -93,7 +108,7 @@ def run_sft( predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) - trainer.save_predictions(predict_results) + trainer.save_predictions(dataset, predict_results) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/utils.py b/src/llamafactory/train/trainer_utils.py similarity index 75% rename from src/llamafactory/train/utils.py rename to src/llamafactory/train/trainer_utils.py index b189922b..98c38842 100644 --- a/src/llamafactory/train/utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -1,11 +1,33 @@ -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 
+# +# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore +# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus +# and the original BAdam's implementation: https://github.com/Ledzy/BAdam +# and the HuggingFace's TRL library: https://github.com/huggingface/trl +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import torch +from peft import PeftModel from transformers import Trainer from transformers.optimization import get_scheduler from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.trainer_pt_utils import get_parameter_names +from ..extras.constants import IGNORE_INDEX from ..extras.logging import get_logger from ..extras.packages import is_galore_available from ..hparams import FinetuningArguments, ModelArguments @@ -17,8 +39,8 @@ if is_galore_available(): if TYPE_CHECKING: - from transformers import Seq2SeqTrainingArguments - from transformers.modeling_utils import PreTrainedModel + from accelerate import Accelerator + from transformers import PreTrainedModel, Seq2SeqTrainingArguments from trl import AutoModelForCausalLMWithValueHead from ..hparams import DataArguments @@ -81,15 +103,12 @@ def create_ref_model( The valuehead parameter is randomly initialized since it is useless for PPO training. 
""" if finetuning_args.ref_model is not None: - ref_model_args_dict = model_args.to_dict() - ref_model_args_dict.update( - dict( - model_name_or_path=finetuning_args.ref_model, - adapter_name_or_path=finetuning_args.ref_model_adapters, - quantization_bit=finetuning_args.ref_model_quantization_bit, - ) + ref_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.ref_model, + adapter_name_or_path=finetuning_args.ref_model_adapters, + quantization_bit=finetuning_args.ref_model_quantization_bit, ) - ref_model_args = ModelArguments(**ref_model_args_dict) ref_finetuning_args = FinetuningArguments() tokenizer = load_tokenizer(ref_model_args)["tokenizer"] ref_model = load_model( @@ -100,9 +119,11 @@ def create_ref_model( if finetuning_args.finetuning_type == "lora": ref_model = None else: - tokenizer = load_tokenizer(model_args)["tokenizer"] + ref_model_args = ModelArguments.copyfrom(model_args) + ref_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(ref_model_args)["tokenizer"] ref_model = load_model( - tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead + tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead ) logger.info("Created reference model from the model itself.") @@ -137,15 +158,12 @@ def create_reward_model( logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) return None else: - reward_model_args_dict = model_args.to_dict() - reward_model_args_dict.update( - dict( - model_name_or_path=finetuning_args.reward_model, - adapter_name_or_path=finetuning_args.reward_model_adapters, - quantization_bit=finetuning_args.reward_model_quantization_bit, - ) + reward_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.reward_model, + adapter_name_or_path=finetuning_args.reward_model_adapters, + quantization_bit=finetuning_args.reward_model_quantization_bit, ) - reward_model_args = ModelArguments(**reward_model_args_dict) reward_finetuning_args = FinetuningArguments() tokenizer = load_tokenizer(reward_model_args)["tokenizer"] reward_model = load_model( @@ -156,6 +174,50 @@ def create_reward_model( return reward_model +def convert_pissa_adapter( + output_dir: str, + state_dict: Dict[str, "torch.Tensor"], + accelerator: "Accelerator", + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", +) -> None: + r""" + Converts the PiSSA adapter to a LoRA adapter. 
+    """
+    pissa_init_dir = os.path.join(training_args.output_dir, "pissa_init")
+    pissa_backup_dir = os.path.join(output_dir, "pissa_backup")
+    if output_dir == pissa_init_dir:
+        logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
+        unwrapped_model = accelerator.unwrap_model(model)
+        if isinstance(unwrapped_model, PeftModel):
+            init_lora_weights = getattr(unwrapped_model.peft_config["default"], "init_lora_weights")
+            setattr(unwrapped_model.peft_config["default"], "init_lora_weights", True)
+            unwrapped_model.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=training_args.save_safetensors,
+            )
+            setattr(unwrapped_model.peft_config["default"], "init_lora_weights", init_lora_weights)
+    elif output_dir == training_args.output_dir:  # at the end of training
+        logger.info("Converted PiSSA adapter will be saved at: {}.".format(output_dir))
+        unwrapped_model = accelerator.unwrap_model(model)
+        if isinstance(unwrapped_model, PeftModel):  # backup the pissa adapter for further use
+            unwrapped_model.save_pretrained(
+                pissa_backup_dir,
+                state_dict=state_dict,
+                safe_serialization=training_args.save_safetensors,
+            )
+            unwrapped_model.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=training_args.save_safetensors,
+                convert_pissa_to_lora=pissa_init_dir,
+            )
+            # TODO: the model unexpectedly gets PiSSA applied again here
+            unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
+            unwrapped_model.set_adapter("default")
+
+
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
@@ -386,6 +448,7 @@ def create_custom_scheduler(
                 optimizer=optimizer_dict[param],
                 num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
+                scheduler_specific_kwargs=training_args.lr_scheduler_kwargs,
             )

         def scheduler_hook(param: "torch.nn.Parameter"):
@@ -393,3 +456,24 @@
     for param in optimizer_dict.keys():
         param.register_post_accumulate_grad_hook(scheduler_hook)
+
+
+def get_batch_logps(
+    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+) -> Tuple["torch.Tensor", "torch.Tensor"]:
+    r"""
+    Computes the log probabilities of the given labels under the given logits.
+
+    Returns:
+        logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
+        valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
+    """
+    if logits.shape[:-1] != labels.shape:
+        raise ValueError("Logits (batch_size x seq_len) and labels must have the same shape.")
+
+    labels = labels[:, 1:].clone()
+    logits = logits[:, :-1, :]
+    loss_mask = labels != label_pad_token_id
+    labels[labels == label_pad_token_id] = 0  # dummy token
+    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index eed875e9..788b4c4f 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index a92f6ef7..864c41c7 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple @@ -6,6 +20,7 @@ from numpy.typing import NDArray from ..chat import ChatModel from ..data import Role +from ..extras.constants import PEFT_METHODS from ..extras.misc import torch_gc from ..extras.packages import is_gradio_available from .common import get_save_dir @@ -44,13 +59,14 @@ class WebChatModel(ChatModel): def load_model(self, data) -> Generator[str, None, None]: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] - lang = get("top.lang") + lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path") + finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path") error = "" if self.loaded: error = ALERTS["err_exists"][lang] - elif not get("top.model_name"): + elif not model_name: error = ALERTS["err_no_model"][lang] - elif not get("top.model_path"): + elif not model_path: error = ALERTS["err_no_path"][lang] elif self.demo_mode: error = ALERTS["err_demo"][lang] @@ -60,21 +76,10 @@ class WebChatModel(ChatModel): yield error return - if get("top.adapter_path"): - adapter_name_or_path = ",".join( - [ - get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) - for adapter in get("top.adapter_path") - ] - ) - else: - adapter_name_or_path = None - yield ALERTS["info_loading"][lang] args = dict( - model_name_or_path=get("top.model_path"), - adapter_name_or_path=adapter_name_or_path, - finetuning_type=get("top.finetuning_type"), + model_name_or_path=model_path, + finetuning_type=finetuning_type, quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", @@ -83,8 +88,16 @@ class WebChatModel(ChatModel): rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, infer_backend=get("infer.infer_backend"), ) - super().__init__(args) + if checkpoint_path: + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path] + ) + else: # str + 
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path) + + super().__init__(args) yield ALERTS["info_loaded"][lang] def unload_model(self, data) -> Generator[str, None, None]: diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index ea82fd88..980428a4 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -1,14 +1,27 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from collections import defaultdict -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple -from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME from yaml import safe_dump, safe_load from ..extras.constants import ( + CHECKPOINT_NAMES, DATA_CONFIG, - DEFAULT_MODULE, DEFAULT_TEMPLATE, PEFT_METHODS, STAGES_USE_PAIR_DATA, @@ -29,7 +42,6 @@ if is_gradio_available(): logger = get_logger(__name__) -ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME} DEFAULT_CACHE_DIR = "cache" DEFAULT_CONFIG_DIR = "config" DEFAULT_DATA_DIR = "data" @@ -38,19 +50,28 @@ USER_CONFIG = "user_config.yaml" def get_save_dir(*paths: str) -> os.PathLike: - paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths) + r""" + Gets the path to saved model checkpoints. + """ + if os.path.sep in paths[-1]: + logger.warning("Found complex path, some features may be not available.") + return paths[-1] + + paths = (path.replace(" ", "").strip() for path in paths) return os.path.join(DEFAULT_SAVE_DIR, *paths) def get_config_path() -> os.PathLike: + r""" + Gets the path to user config. + """ return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) -def get_save_path(config_path: str) -> os.PathLike: - return os.path.join(DEFAULT_CONFIG_DIR, config_path) - - def load_config() -> Dict[str, Any]: + r""" + Loads user config if exists. + """ try: with open(get_config_path(), "r", encoding="utf-8") as f: return safe_load(f) @@ -59,80 +80,98 @@ def load_config() -> Dict[str, Any]: def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None: + r""" + Saves user config. 
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
     if model_name:
         user_config["last_model"] = model_name
+
+    if model_name and model_path:
         user_config["path_dict"][model_name] = model_path
+
     with open(get_config_path(), "w", encoding="utf-8") as f:
         safe_dump(user_config, f)


-def load_args(config_path: str) -> Optional[Dict[str, Any]]:
-    try:
-        with open(get_save_path(config_path), "r", encoding="utf-8") as f:
-            return safe_load(f)
-    except Exception:
-        return None
-
-
-def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
-    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
-    with open(get_save_path(config_path), "w", encoding="utf-8") as f:
-        safe_dump(config_dict, f)
-
-    return str(get_save_path(config_path))
-
-
 def get_model_path(model_name: str) -> str:
+    r"""
+    Gets the model path according to the model name.
+    """
     user_config = load_config()
-    path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
-    model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
+    path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
+    model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
     if (
         use_modelscope()
         and path_dict.get(DownloadSource.MODELSCOPE)
         and model_path == path_dict.get(DownloadSource.DEFAULT)
     ):  # replace path
         model_path = path_dict.get(DownloadSource.MODELSCOPE)
+
     return model_path


 def get_prefix(model_name: str) -> str:
+    r"""
+    Gets the prefix of the model name to obtain the model family.
+    """
     return model_name.split("-")[0]


-def get_module(model_name: str) -> str:
-    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
+def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+    r"""
+    Gets the necessary information for this model.
+
+    Returns:
+        model_path (str)
+        template (str)
+        visual (bool)
+    """
+    return get_model_path(model_name), get_template(model_name), get_visual(model_name)


 def get_template(model_name: str) -> str:
+    r"""
+    Gets the template name if the model is a chat model.
+    """
     if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
         return DEFAULT_TEMPLATE[get_prefix(model_name)]
     return "default"


 def get_visual(model_name: str) -> bool:
+    r"""
+    Checks whether the model is a vision language model.
+    """
     return get_prefix(model_name) in VISION_MODELS


-def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
-    if finetuning_type not in PEFT_METHODS:
-        return gr.Dropdown(value=[], choices=[], interactive=False)
-
-    adapters = []
-    if model_name and finetuning_type == "lora":
+def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
+    r"""
+    Lists all available checkpoints.
+ """ + checkpoints = [] + if model_name: save_dir = get_save_dir(model_name, finetuning_type) if save_dir and os.path.isdir(save_dir): - for adapter in os.listdir(save_dir): - if os.path.isdir(os.path.join(save_dir, adapter)) and any( - os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES + for checkpoint in os.listdir(save_dir): + if os.path.isdir(os.path.join(save_dir, checkpoint)) and any( + os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES ): - adapters.append(adapter) - return gr.Dropdown(value=[], choices=adapters, interactive=True) + checkpoints.append(checkpoint) + + if finetuning_type in PEFT_METHODS: + return gr.Dropdown(value=[], choices=checkpoints, multiselect=True) + else: + return gr.Dropdown(value=None, choices=checkpoints, multiselect=False) def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: + r""" + Loads dataset_info.json. + """ if dataset_dir == "ONLINE": logger.info("dataset_dir is ONLINE, using online dataset.") return {} @@ -145,12 +184,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: return {} -def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": +def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": + r""" + Lists all available datasets in the dataset dir for the training stage. + """ dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] - return gr.Dropdown(value=[], choices=datasets) - - -def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button": - return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt")) + return gr.Dropdown(choices=datasets) diff --git a/src/llamafactory/webui/components/__init__.py b/src/llamafactory/webui/components/__init__.py index 5c1e21b8..715fb6e4 100644 --- a/src/llamafactory/webui/components/__init__.py +++ b/src/llamafactory/webui/components/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .chatbot import create_chat_box from .eval import create_eval_tab from .export import create_export_tab diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py index f83694b1..ad74114b 100644 --- a/src/llamafactory/webui/components/chatbot.py +++ b/src/llamafactory/webui/components/chatbot.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict, Tuple from ...data import Role diff --git a/src/llamafactory/webui/components/data.py b/src/llamafactory/webui/components/data.py index 232b973d..88e500cf 100644 --- a/src/llamafactory/webui/components/data.py +++ b/src/llamafactory/webui/components/data.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os from typing import TYPE_CHECKING, Any, Dict, List, Tuple diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py index 8b70283b..b522913e 100644 --- a/src/llamafactory/webui/components/eval.py +++ b/src/llamafactory/webui/components/eval.py @@ -1,7 +1,21 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import TYPE_CHECKING, Dict from ...extras.packages import is_gradio_available -from ..common import DEFAULT_DATA_DIR, list_dataset +from ..common import DEFAULT_DATA_DIR, list_datasets from .data import create_preview_box @@ -57,7 +71,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): output_box = gr.Markdown() - output_elems = [output_box, progress_bar] elem_dict.update( dict( cmd_preview_btn=cmd_preview_btn, @@ -68,12 +81,13 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: output_box=output_box, ) ) + output_elems = [output_box, progress_bar] cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None) start_btn.click(engine.runner.run_eval, input_elems, output_elems) stop_btn.click(engine.runner.set_abort) resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) - dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False) + dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False) return elem_dict diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index 134b77e0..14257949 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -1,5 +1,20 @@ -from typing import TYPE_CHECKING, Dict, Generator, List +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING, Dict, Generator, List, Union + +from ...extras.constants import PEFT_METHODS from ...extras.misc import torch_gc from ...extras.packages import is_gradio_available from ...train.tuner import export_model @@ -20,12 +35,19 @@ if TYPE_CHECKING: GPTQ_BITS = ["8", "4", "3", "2"] +def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown": + if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: + return gr.Dropdown(value="none", interactive=False) + else: + return gr.Dropdown(interactive=True) + + def save_model( lang: str, model_name: str, model_path: str, - adapter_path: List[str], finetuning_type: str, + checkpoint_path: Union[str, List[str]], template: str, visual_inputs: bool, export_size: int, @@ -45,9 +67,9 @@ def save_model( error = ALERTS["err_no_export_dir"][lang] elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset: error = ALERTS["err_no_dataset"][lang] - elif export_quantization_bit not in GPTQ_BITS and not adapter_path: + elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path: error = ALERTS["err_no_adapter"][lang] - elif export_quantization_bit in GPTQ_BITS and adapter_path: + elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list): error = ALERTS["err_gptq_lora"][lang] if error: @@ -55,16 +77,8 @@ def save_model( yield error return - if adapter_path: - adapter_name_or_path = ",".join( - [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path] - ) - else: - adapter_name_or_path = None - args = dict( model_name_or_path=model_path, - adapter_name_or_path=adapter_name_or_path, finetuning_type=finetuning_type, template=template, visual_inputs=visual_inputs, @@ -77,6 +91,14 @@ def save_model( export_legacy_format=export_legacy_format, ) + if checkpoint_path: + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path) + yield ALERTS["info_exporting"][lang] export_model(args) torch_gc() @@ -86,15 +108,18 @@ def save_model( def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1) - export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none") + export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none") export_quantization_dataset = gr.Textbox(value="data/c4_demo.json") - export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu") + export_device = gr.Radio(choices=["cpu", "auto"], value="cpu") export_legacy_format = gr.Checkbox() with gr.Row(): export_dir = gr.Textbox() export_hub_model_id = gr.Textbox() + checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path") + checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False) + export_btn = gr.Button() info_box = gr.Textbox(show_label=False, interactive=False) @@ -104,8 +129,8 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: engine.manager.get_elem_by_id("top.lang"), engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.model_path"), - engine.manager.get_elem_by_id("top.adapter_path"), engine.manager.get_elem_by_id("top.finetuning_type"), + engine.manager.get_elem_by_id("top.checkpoint_path"), 
engine.manager.get_elem_by_id("top.template"), engine.manager.get_elem_by_id("top.visual_inputs"), export_size, diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py index 970f4629..03bccd7f 100644 --- a/src/llamafactory/webui/components/infer.py +++ b/src/llamafactory/webui/components/infer.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict from ...extras.packages import is_gradio_available diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index a75a4d62..2515a83d 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -1,9 +1,23 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import TYPE_CHECKING, Dict -from ...data import templates +from ...data import TEMPLATES from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.packages import is_gradio_available -from ..common import get_model_path, get_template, get_visual, list_adapters, save_config +from ..common import get_model_info, list_checkpoints, save_config from ..utils import can_quantize @@ -25,38 +39,28 @@ def create_top() -> Dict[str, "Component"]: with gr.Row(): finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) - adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5) - refresh_btn = gr.Button(scale=1) + checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6) with gr.Accordion(open=False) as advanced_tab: with gr.Row(): quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2) - template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2) + template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2) rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3) booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3) visual_inputs = gr.Checkbox(scale=1) - model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then( - get_model_path, [model_name], [model_path], queue=False - ).then(get_template, [model_name], [template], queue=False).then( - get_visual, [model_name], [visual_inputs], queue=False - ) # do not save config since the below line will save - - model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False) - - finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then( - can_quantize, [finetuning_type], [quantization_bit], queue=False - ) - - refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False) + model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False) + model_name.input(save_config, inputs=[lang, model_name], queue=False) + model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False) + finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False) + checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) return dict( lang=lang, model_name=model_name, model_path=model_path, finetuning_type=finetuning_type, - adapter_path=adapter_path, - refresh_btn=refresh_btn, + checkpoint_path=checkpoint_path, advanced_tab=advanced_tab, quantization_bit=quantization_bit, template=template, diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index d399106f..874f3c5e 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -1,11 +1,27 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import TYPE_CHECKING, Dict from transformers.trainer_utils import SchedulerType from ...extras.constants import TRAINING_STAGES +from ...extras.misc import get_device_count from ...extras.packages import is_gradio_available -from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset -from ..components.data import create_preview_box +from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets +from ..utils import change_stage, list_config_paths, list_output_dirs +from .data import create_preview_box if is_gradio_available(): @@ -147,10 +163,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter = gr.Checkbox() with gr.Row(): - with gr.Column(scale=1): - use_rslora = gr.Checkbox() - use_dora = gr.Checkbox() - + use_rslora = gr.Checkbox() + use_dora = gr.Checkbox() + use_pissa = gr.Checkbox() lora_target = gr.Textbox(scale=2) additional_target = gr.Textbox(scale=2) @@ -163,6 +178,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter, use_rslora, use_dora, + use_pissa, lora_target, additional_target, } @@ -177,6 +193,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter=create_new_adapter, use_rslora=use_rslora, use_dora=use_dora, + use_pissa=use_pissa, lora_target=lora_target, additional_target=additional_target, ) @@ -255,8 +272,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): with gr.Column(scale=3): with gr.Row(): - output_dir = gr.Textbox() - config_path = gr.Textbox() + current_time = gr.Textbox(visible=False, interactive=False) + output_dir = gr.Dropdown(allow_custom_value=True) + config_path = gr.Dropdown(allow_custom_value=True) + + with gr.Row(): + device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False) + ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none") + ds_offload = gr.Checkbox() with gr.Row(): resume_btn = gr.Checkbox(visible=False, interactive=False) @@ -268,6 +291,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Column(scale=1): loss_viewer = gr.Plot() + input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload}) elem_dict.update( dict( cmd_preview_btn=cmd_preview_btn, @@ -275,36 +299,48 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: arg_load_btn=arg_load_btn, start_btn=start_btn, stop_btn=stop_btn, + current_time=current_time, output_dir=output_dir, config_path=config_path, + device_count=device_count, + ds_stage=ds_stage, + ds_offload=ds_offload, resume_btn=resume_btn, progress_bar=progress_bar, output_box=output_box, loss_viewer=loss_viewer, ) ) - - input_elems.update({output_dir, config_path}) output_elems = [output_box, progress_bar, loss_viewer] cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None) - arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) - arg_load_btn.click( - engine.runner.load_args, - [engine.manager.get_elem_by_id("top.lang"), config_path], - list(input_elems) + [output_box], - concurrency_limit=None, - ) start_btn.click(engine.runner.run_train, input_elems, output_elems) stop_btn.click(engine.runner.set_abort) resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) - dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) - training_stage.change(list_dataset, [dataset_dir, training_stage], 
[dataset], queue=False).then( - list_adapters, - [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")], - [reward_model], - queue=False, - ).then(autoset_packing, [training_stage], [packing], queue=False) + lang = engine.manager.get_elem_by_id("top.lang") + model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name") + finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type") + + arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) + arg_load_btn.click( + engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None + ) + + dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) + training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) + reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False) + model_name.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) + finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) + output_dir.change( + list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None + ) + output_dir.input( + engine.runner.check_output_dir, + [lang, model_name, finetuning_type, output_dir], + list(input_elems) + [output_box], + concurrency_limit=None, + ) + config_path.change(list_config_paths, [current_time], [config_path], queue=False) return elem_dict diff --git a/src/llamafactory/webui/css.py b/src/llamafactory/webui/css.py index 36e3d4c2..53982119 100644 --- a/src/llamafactory/webui/css.py +++ b/src/llamafactory/webui/css.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + CSS = r""" .duplicate-button { margin: auto !important; diff --git a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py index 964d65a2..04893215 100644 --- a/src/llamafactory/webui/engine.py +++ b/src/llamafactory/webui/engine.py @@ -1,11 +1,25 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import TYPE_CHECKING, Any, Dict from .chatter import WebChatModel -from .common import get_model_path, list_dataset, load_config +from .common import load_config from .locales import LOCALES from .manager import Manager from .runner import Runner -from .utils import get_time +from .utils import create_ds_config, get_time if TYPE_CHECKING: @@ -19,6 +33,8 @@ class Engine: self.manager = Manager() self.runner = Runner(self.manager, demo_mode) self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat)) + if not demo_mode: + create_ds_config() def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: r""" @@ -38,16 +54,15 @@ class Engine: init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}} if not self.pure_chat: - init_dict["train.dataset"] = {"choices": list_dataset().choices} - init_dict["eval.dataset"] = {"choices": list_dataset().choices} - init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())} - init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())} - init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())} + current_time = get_time() + init_dict["train.current_time"] = {"value": current_time} + init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)} + init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)} + init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)} init_dict["infer.image_box"] = {"visible": False} if user_config.get("last_model", None): init_dict["top.model_name"] = {"value": user_config["last_model"]} - init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])} yield self._update_component(init_dict) diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py index bae3ba76..d25f4d38 100644 --- a/src/llamafactory/webui/interface.py +++ b/src/llamafactory/webui/interface.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from ..extras.packages import is_gradio_available diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index bd4a4205..8e8d6fce 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ LOCALES = { "lang": { "en": { @@ -46,26 +60,15 @@ LOCALES = { "label": "微调方法", }, }, - "adapter_path": { + "checkpoint_path": { "en": { - "label": "Adapter path", + "label": "Checkpoint path", }, "ru": { - "label": "Путь к адаптеру", + "label": "Путь контрольной точки", }, "zh": { - "label": "适配器路径", - }, - }, - "refresh_btn": { - "en": { - "value": "Refresh adapters", - }, - "ru": { - "value": "Обновить адаптеры", - }, - "zh": { - "value": "刷新适配器", + "label": "检查点路径", }, }, "advanced_tab": { @@ -729,6 +732,20 @@ LOCALES = { "info": "使用权重分解的 LoRA。", }, }, + "use_pissa": { + "en": { + "label": "Use PiSSA", + "info": "Use PiSSA method.", + }, + "ru": { + "label": "используйте PiSSA", + "info": "Используйте метод PiSSA.", + }, + "zh": { + "label": "使用 PiSSA", + "info": "使用 PiSSA 方法。", + }, + }, "lora_target": { "en": { "label": "LoRA modules (optional)", @@ -1103,6 +1120,48 @@ LOCALES = { "info": "保存训练参数的配置文件路径。", }, }, + "device_count": { + "en": { + "label": "Device count", + "info": "Number of devices available.", + }, + "ru": { + "label": "Количество устройств", + "info": "Количество доступных устройств.", + }, + "zh": { + "label": "设备数量", + "info": "当前可用的运算设备数。", + }, + }, + "ds_stage": { + "en": { + "label": "DeepSpeed stage", + "info": "DeepSpeed stage for distributed training.", + }, + "ru": { + "label": "Этап DeepSpeed", + "info": "Этап DeepSpeed для распределенного обучения.", + }, + "zh": { + "label": "DeepSpeed stage", + "info": "多卡训练的 DeepSpeed stage。", + }, + }, + "ds_offload": { + "en": { + "label": "Enable offload", + "info": "Enable DeepSpeed offload (slow down training).", + }, + "ru": { + "label": "Включить выгрузку", + "info": "включить выгрузку DeepSpeed (замедлит обучение).", + }, + "zh": { + "label": "使用 offload", + "info": "使用 DeepSpeed offload(会减慢速度)。", + }, + }, "output_box": { "en": { "value": "Ready.", @@ -1444,6 +1503,11 @@ ALERTS = { "ru": "Пожалуйста, выберите адаптер.", "zh": "请选择适配器。", }, + "err_no_output_dir": { + "en": "Please provide output dir.", + "ru": "Пожалуйста, укажите выходную директорию.", + "zh": "请填写输出目录。", + }, "err_no_reward_model": { "en": "Please select a reward model.", "ru": "Пожалуйста, выберите модель вознаграждения.", @@ -1469,11 +1533,6 @@ ALERTS = { "ru": "Обучение недоступно в демонстрационном режиме, сначала скопируйте пространство в частное.", "zh": "展示模式不支持训练,请先复制到私人空间。", }, - "err_device_count": { - "en": "Multiple GPUs are not supported yet.", - "ru": "Пока не поддерживается множественные GPU.", - "zh": "尚不支持多 GPU 训练。", - }, "err_tool_name": { "en": "Tool name not found.", "ru": "Имя инструмента не найдено.", @@ -1494,6 +1553,11 @@ ALERTS = { "ru": "Среда CUDA не обнаружена.", "zh": "未检测到 CUDA 环境。", }, + "warn_output_dir_exists": { + "en": "Output dir already exists, will resume training from here.", + "ru": "Выходной каталог уже существует, обучение будет продолжено отсюда.", + "zh": "输出目录已存在,将从该断点恢复训练。", + }, "info_aborting": { "en": "Aborted, wait for terminating...", "ru": "Прервано, ожидание завершения...", diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py index f65fa804..7e9b801a 100644 --- a/src/llamafactory/webui/manager.py +++ b/src/llamafactory/webui/manager.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple @@ -55,7 +69,7 @@ class Manager: self._id_to_elem["top.model_name"], self._id_to_elem["top.model_path"], self._id_to_elem["top.finetuning_type"], - self._id_to_elem["top.adapter_path"], + self._id_to_elem["top.checkpoint_path"], self._id_to_elem["top.quantization_bit"], self._id_to_elem["top.template"], self._id_to_elem["top.rope_scaling"], diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 57595a08..13dbba03 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -1,19 +1,30 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os -import signal from copy import deepcopy from subprocess import Popen, TimeoutExpired from typing import TYPE_CHECKING, Any, Dict, Generator, Optional -import psutil from transformers.trainer import TRAINING_ARGS_NAME -from transformers.utils import is_torch_cuda_available -from ..extras.constants import TRAINING_STAGES -from ..extras.misc import get_device_count, torch_gc +from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES +from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.packages import is_gradio_available -from .common import get_module, get_save_dir, load_args, load_config, save_args -from .locales import ALERTS -from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config +from .locales import ALERTS, LOCALES +from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd if is_gradio_available(): @@ -41,8 +52,7 @@ class Runner: def set_abort(self) -> None: self.aborted = True if self.trainer is not None: - for children in psutil.Process(self.trainer.pid).children(): # abort the child process - os.kill(children.pid, signal.SIGABRT) + abort_leaf_process(self.trainer.pid) def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] @@ -64,16 +74,18 @@ class Runner: if not from_preview and self.demo_mode: return ALERTS["err_demo"][lang] - if not from_preview and get_device_count() > 1: - return ALERTS["err_device_count"][lang] - if do_train: - stage = TRAINING_STAGES[get("train.training_stage")] - reward_model = get("train.reward_model") - if stage == "ppo" and not reward_model: - return ALERTS["err_no_reward_model"][lang] + if not get("train.output_dir"): + return 
ALERTS["err_no_output_dir"][lang] - if not from_preview and not is_torch_cuda_available(): + stage = TRAINING_STAGES[get("train.training_stage")] + if stage == "ppo" and not get("train.reward_model"): + return ALERTS["err_no_reward_model"][lang] + else: + if not get("eval.output_dir"): + return ALERTS["err_no_output_dir"][lang] + + if not from_preview and not is_gpu_or_npu_available(): gr.Warning(ALERTS["warn_no_cuda"][lang]) return "" @@ -89,26 +101,16 @@ class Runner: def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() - if get("top.adapter_path"): - adapter_name_or_path = ",".join( - [ - get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) - for adapter in get("top.adapter_path") - ] - ) - else: - adapter_name_or_path = None - args = dict( stage=TRAINING_STAGES[get("train.training_stage")], do_train=True, model_name_or_path=get("top.model_path"), - adapter_name_or_path=adapter_name_or_path, cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, - finetuning_type=get("top.finetuning_type"), + finetuning_type=finetuning_type, quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, @@ -138,13 +140,24 @@ class Runner: report_to="all" if get("train.report_to") else "none", use_galore=get("train.use_galore"), use_badam=get("train.use_badam"), - output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")), + output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")), fp16=(get("train.compute_type") == "fp16"), bf16=(get("train.compute_type") == "bf16"), pure_bf16=(get("train.compute_type") == "pure_bf16"), plot_loss=True, + ddp_timeout=180000000, + include_num_input_tokens_seen=True, ) + # checkpoints + if get("top.checkpoint_path"): + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path")) + # freeze config if args["finetuning_type"] == "freeze": args["freeze_trainable_layers"] = get("train.freeze_trainable_layers") @@ -160,7 +173,9 @@ class Runner: args["create_new_adapter"] = get("train.create_new_adapter") args["use_rslora"] = get("train.use_rslora") args["use_dora"] = get("train.use_dora") - args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name")) + args["pissa_init"] = get("train.use_pissa") + args["pissa_convert"] = get("train.use_pissa") + args["lora_target"] = get("train.lora_target") or "all" args["additional_target"] = get("train.additional_target") or None if args["use_llama_pro"]: @@ -168,13 +183,14 @@ class Runner: # rlhf config if args["stage"] == "ppo": - args["reward_model"] = ",".join( - [ - get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) - for adapter in get("train.reward_model") - ] - ) - args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full" + if finetuning_type in PEFT_METHODS: + args["reward_model"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in 
get("train.reward_model")] + ) + else: + args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model")) + + args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full" args["ppo_score_norm"] = get("train.ppo_score_norm") args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards") args["top_k"] = 0 @@ -201,33 +217,29 @@ class Runner: # eval config if get("train.val_size") > 1e-6 and args["stage"] != "ppo": args["val_size"] = get("train.val_size") - args["evaluation_strategy"] = "steps" + args["eval_strategy"] = "steps" args["eval_steps"] = args["save_steps"] args["per_device_eval_batch_size"] = args["per_device_train_batch_size"] + # ds config + if get("train.ds_stage") != "none": + ds_stage = get("train.ds_stage") + ds_offload = "offload_" if get("train.ds_offload") else "" + args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload)) + return args def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() - if get("top.adapter_path"): - adapter_name_or_path = ",".join( - [ - get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) - for adapter in get("top.adapter_path") - ] - ) - else: - adapter_name_or_path = None - args = dict( stage="sft", model_name_or_path=get("top.model_path"), - adapter_name_or_path=adapter_name_or_path, cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, - finetuning_type=get("top.finetuning_type"), + finetuning_type=finetuning_type, quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, @@ -243,7 +255,7 @@ class Runner: max_new_tokens=get("eval.max_new_tokens"), top_p=get("eval.top_p"), temperature=get("eval.temperature"), - output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")), + output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")), ) if get("eval.predict"): @@ -251,6 +263,14 @@ class Runner: else: args["do_eval"] = True + if get("top.checkpoint_path"): + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path")) + return args def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]: @@ -272,12 +292,28 @@ class Runner: else: self.do_train, self.running_data = do_train, data args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) + + os.makedirs(args["output_dir"], exist_ok=True) + save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data)) + env = deepcopy(os.environ) - env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0") env["LLAMABOARD_ENABLED"] = "1" + if args.get("deepspeed", None) is not None: + env["FORCE_TORCHRUN"] = "1" + self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True) yield from self.monitor() + def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: + 
config_dict = {} + skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"] + for elem, value in data.items(): + elem_id = self.manager.get_id_by_elem(elem) + if elem_id not in skip_ids: + config_dict[elem_id] = value + + return config_dict + def preview_train(self, data): yield from self._preview(data, do_train=True) @@ -295,9 +331,7 @@ class Runner: self.running = True get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)] - lang = get("top.lang") - model_name = get("top.model_name") - finetuning_type = get("top.finetuning_type") + lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type") output_dir = get("{}.output_dir".format("train" if self.do_train else "eval")) output_path = get_save_dir(model_name, finetuning_type, output_dir) @@ -345,28 +379,24 @@ class Runner: } yield return_dict - def save_args(self, data: dict): + def save_args(self, data): output_box = self.manager.get_elem_by_id("train.output_box") error = self._initialize(data, do_train=True, from_preview=True) if error: gr.Warning(error) return {output_box: error} - config_dict: Dict[str, Any] = {} lang = data[self.manager.get_elem_by_id("top.lang")] config_path = data[self.manager.get_elem_by_id("train.config_path")] - skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"] - for elem, value in data.items(): - elem_id = self.manager.get_id_by_elem(elem) - if elem_id not in skip_ids: - config_dict[elem_id] = value + os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True) + save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path) - save_path = save_args(config_path, config_dict) + save_args(save_path, self._form_config_dict(data)) return {output_box: ALERTS["info_config_saved"][lang] + save_path} def load_args(self, lang: str, config_path: str): output_box = self.manager.get_elem_by_id("train.output_box") - config_dict = load_args(config_path) + config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path)) if config_dict is None: gr.Warning(ALERTS["err_config_not_found"][lang]) return {output_box: ALERTS["err_config_not_found"][lang]} @@ -376,3 +406,17 @@ class Runner: output_dict[self.manager.get_elem_by_id(elem_id)] = value return output_dict + + def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str): + output_box = self.manager.get_elem_by_id("train.output_box") + output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]} + if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): + gr.Warning(ALERTS["warn_output_dir_exists"][lang]) + output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang] + + output_dir = get_save_dir(model_name, finetuning_type, output_dir) + config_dict = load_args(os.path.join(output_dir, LLAMABOARD_CONFIG)) # load llamaboard config + for elem_id, value in config_dict.items(): + output_dict[self.manager.get_elem_by_id(elem_id)] = value + + return output_dict diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index 3d34f0d2..6ce2a8e7 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -1,13 +1,31 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json import os +import signal from datetime import datetime from typing import Any, Dict, List, Optional, Tuple -from yaml import safe_dump +import psutil +from transformers.trainer_utils import get_last_checkpoint +from yaml import safe_dump, safe_load -from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG +from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.ploting import gen_loss_plot +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir from .locales import ALERTS @@ -15,14 +33,39 @@ if is_gradio_available(): import gradio as gr +def abort_leaf_process(pid: int) -> None: + r""" + Aborts the leaf processes of the given process. + """ + children = psutil.Process(pid).children() + if children: + for child in children: + abort_leaf_process(child.pid) + else: + os.kill(pid, signal.SIGABRT) + + def can_quantize(finetuning_type: str) -> "gr.Dropdown": - if finetuning_type != "lora": + r""" + Checks whether quantization is available for this finetuning type. + """ + if finetuning_type not in PEFT_METHODS: return gr.Dropdown(value="none", interactive=False) else: return gr.Dropdown(interactive=True) +def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: + r""" + Modifies states after changing the training stage. + """ + return [], TRAINING_STAGES[training_stage] == "pt" + + def check_json_schema(text: str, lang: str) -> None: + r""" + Checks if the JSON schema is valid. + """ try: tools = json.loads(text) if tools: @@ -37,13 +80,18 @@ def check_json_schema(text: str, lang: str) -> None: def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]: + r""" + Removes args with None, False, or empty string values. + """ no_skip_keys = ["packing"] return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")} def gen_cmd(args: Dict[str, Any]) -> str: - current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0") - cmd_lines = ["CUDA_VISIBLE_DEVICES={} llamafactory-cli train ".format(current_devices)] + r""" + Generates the command text for previewing. + """ + cmd_lines = ["llamafactory-cli train "] for k, v in clean_cmd(args).items(): cmd_lines.append(" --{} {} ".format(k, str(v))) @@ -52,17 +100,39 @@ def gen_cmd(args: Dict[str, Any]) -> str: return cmd_text +def save_cmd(args: Dict[str, Any]) -> str: + r""" + Saves arguments to launch training. + """ + output_dir = args["output_dir"] + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f: + safe_dump(clean_cmd(args), f) + + return os.path.join(output_dir, TRAINING_ARGS) + + def get_eval_results(path: os.PathLike) -> str: + r""" + Gets scores after evaluation. + """ with open(path, "r", encoding="utf-8") as f: result = json.dumps(json.load(f), indent=4) return "```json\n{}\n```\n".format(result) def get_time() -> str: + r""" + Gets current date and time.
+ """ return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: + r""" + Gets training information for the monitor. + """ running_log = "" running_progress = gr.Slider(visible=False) running_loss = None @@ -96,11 +166,112 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr return running_log, running_progress, running_loss -def save_cmd(args: Dict[str, Any]) -> str: - output_dir = args["output_dir"] - os.makedirs(output_dir, exist_ok=True) +def load_args(config_path: str) -> Optional[Dict[str, Any]]: + r""" + Loads saved arguments. + """ + try: + with open(config_path, "r", encoding="utf-8") as f: + return safe_load(f) + except Exception: + return None - with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f: - safe_dump(clean_cmd(args), f) - return os.path.join(output_dir, TRAINER_CONFIG) +def save_args(config_path: str, config_dict: Dict[str, Any]): + r""" + Saves arguments. + """ + with open(config_path, "w", encoding="utf-8") as f: + safe_dump(config_dict, f) + + +def list_config_paths(current_time: str) -> "gr.Dropdown": + r""" + Lists all the saved configuration files. + """ + config_files = ["{}.yaml".format(current_time)] + if os.path.isdir(DEFAULT_CONFIG_DIR): + for file_name in os.listdir(DEFAULT_CONFIG_DIR): + if file_name.endswith(".yaml") and file_name not in config_files: + config_files.append(file_name) + + return gr.Dropdown(choices=config_files) + + +def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": + r""" + Lists all the directories from which training can be resumed. + """ + output_dirs = ["train_{}".format(current_time)] + if model_name: + save_dir = get_save_dir(model_name, finetuning_type) + if save_dir and os.path.isdir(save_dir): + for folder in os.listdir(save_dir): + output_dir = os.path.join(save_dir, folder) + if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None: + output_dirs.append(folder) + + return gr.Dropdown(choices=output_dirs) + + +def create_ds_config() -> None: + r""" + Creates the DeepSpeed config files.
+ """ + os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) + ds_config = { + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": True, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1, + }, + "bf16": {"enabled": "auto"}, + } + offload_config = { + "device": "cpu", + "pin_memory": True, + } + ds_config["zero_optimization"] = { + "stage": 2, + "allgather_partitions": True, + "allgather_bucket_size": 5e8, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + "contiguous_gradients": True, + "round_robin_gradients": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"] = { + "stage": 3, + "overlap_comm": True, + "contiguous_gradients": True, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": True, + } + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) + + ds_config["zero_optimization"]["offload_optimizer"] = offload_config + ds_config["zero_optimization"]["offload_param"] = offload_config + with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f: + json.dump(ds_config, f, indent=2) diff --git a/src/train.py b/src/train.py index b20aa9d2..6703ffdb 100644 --- a/src/train.py +++ b/src/train.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from llamafactory.train.tuner import run_exp diff --git a/src/webui.py b/src/webui.py index bbefb54e..99370af2 100644 --- a/src/webui.py +++ b/src/webui.py @@ -1,3 +1,17 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
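For reference, a minimal standalone sketch (not part of the patch) of the filename scheme shared by `create_ds_config()` and `Runner._parse_train_args()` above; the `DEFAULT_CACHE_DIR` value is a stand-in for the constant defined in `src/llamafactory/webui/common.py`:

```python
import os

DEFAULT_CACHE_DIR = "cache"  # stand-in; the real value comes from llamafactory.webui.common


def ds_config_path(ds_stage: str, ds_offload: bool) -> str:
    # Mirrors the runner: "ds_z{stage}_config.json" or "ds_z{stage}_offload_config.json"
    offload = "offload_" if ds_offload else ""
    return os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, offload))


assert ds_config_path("2", False).endswith("ds_z2_config.json")
assert ds_config_path("3", True).endswith("ds_z3_offload_config.json")
```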
+ import os from llamafactory.webui.interface import create_ui diff --git a/tests/data/test_supervised.py b/tests/data/test_supervised.py new file mode 100644 index 00000000..9f7b2dbf --- /dev/null +++ b/tests/data/test_supervised.py @@ -0,0 +1,64 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random + +import pytest +from datasets import load_dataset + +from llamafactory.data import get_dataset +from llamafactory.hparams import get_train_args +from llamafactory.model import load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "full", + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 8192, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + + +@pytest.mark.parametrize("num_samples", [16]) +def test_supervised(num_samples: int): + model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS) + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) + + original_data = load_dataset(TRAIN_ARGS["dataset"], split="train") + indexes = random.choices(range(len(original_data)), k=num_samples) + for index in indexes: + decoded_result = tokenizer.decode(tokenized_data["input_ids"][index]) + prompt = original_data[index]["instruction"] + if original_data[index]["input"]: + prompt += "\n" + original_data[index]["input"] + + messages = [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": original_data[index]["output"]}, + ] + templated_result = tokenizer.apply_chat_template(messages, tokenize=False) + assert decoded_result == templated_result diff --git a/tests/eval/test_eval_template.py b/tests/eval/test_eval_template.py new file mode 100644 index 00000000..f85d9d57 --- /dev/null +++ b/tests/eval/test_eval_template.py @@ -0,0 +1,91 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
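The new `tests/data/test_supervised.py` above pulls `llamafactory/tiny-random-Llama-3` and a tiny SFT dataset from the Hugging Face Hub. A hedged sketch for running it against a locally cached checkpoint instead (the local path is hypothetical):

```python
import os
import sys

import pytest

# Point the test at a local checkpoint before pytest imports the test module.
os.environ["TINY_LLAMA"] = "/models/tiny-random-Llama-3"  # hypothetical local path
sys.exit(pytest.main(["-v", "tests/data/test_supervised.py"]))
```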
+ +from llamafactory.eval.template import get_eval_template + + +def test_eval_template_en(): + support_set = [ + { + "question": "Fewshot question", + "A": "Fewshot1", + "B": "Fewshot2", + "C": "Fewshot3", + "D": "Fewshot4", + "answer": "B", + } + ] + example = { + "question": "Target question", + "A": "Target1", + "B": "Target2", + "C": "Target3", + "D": "Target4", + "answer": "C", + } + template = get_eval_template(name="en") + messages = template.format_example(example, support_set=support_set, subject_name="SubName") + assert messages == [ + { + "role": "user", + "content": ( + "The following are multiple choice questions (with answers) about SubName.\n\n" + "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:" + ), + }, + {"role": "assistant", "content": "B"}, + { + "role": "user", + "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:", + }, + {"role": "assistant", "content": "C"}, + ] + + +def test_eval_template_zh(): + support_set = [ + { + "question": "示例问题", + "A": "示例答案1", + "B": "示例答案2", + "C": "示例答案3", + "D": "示例答案4", + "answer": "B", + } + ] + example = { + "question": "目标问题", + "A": "目标答案1", + "B": "目标答案2", + "C": "目标答案3", + "D": "目标答案4", + "answer": "C", + } + template = get_eval_template(name="zh") + messages = template.format_example(example, support_set=support_set, subject_name="主题") + assert messages == [ + { + "role": "user", + "content": ( + "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n" + "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:" + ), + }, + {"role": "assistant", "content": "B"}, + { + "role": "user", + "content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:", + }, + {"role": "assistant", "content": "C"}, + ] diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py new file mode 100644 index 00000000..97ac9dcc --- /dev/null +++ b/tests/model/model_utils/test_attention.py @@ -0,0 +1,50 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
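The message lists asserted in `tests/eval/test_eval_template.py` above are plain role/content dicts. Purely as an illustration (this is not how the evaluator is wired in this patch), such a list can be rendered into a single prompt string with any tokenizer that ships a chat template:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-Llama-3")
messages = [
    {"role": "user", "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:"},
]
# Render the few-shot style message list into one prompt string.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```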
+ +import os + +from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available + +from llamafactory.hparams import get_infer_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "template": "llama3", +} + + +def test_attention(): + attention_available = ["off"] + if is_torch_sdpa_available(): + attention_available.append("sdpa") + + if is_flash_attn_2_available(): + attention_available.append("fa2") + + llama_attention_classes = { + "off": "LlamaAttention", + "sdpa": "LlamaSdpaAttention", + "fa2": "LlamaFlashAttention2", + } + for requested_attention in attention_available: + model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args) + for module in model.modules(): + if "Attention" in module.__class__.__name__: + assert module.__class__.__name__ == llama_attention_classes[requested_attention] diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py new file mode 100644 index 00000000..670e693d --- /dev/null +++ b/tests/model/model_utils/test_checkpointing.py @@ -0,0 +1,74 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
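`tests/model/model_utils/test_attention.py` above checks that the `flash_attn` option selects the expected LLaMA attention class. A hedged standalone sketch of the same switch via transformers' `attn_implementation` argument; the mapping of "off" to the eager class is inferred from the class names the test asserts:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "llamafactory/tiny-random-Llama-3",
    attn_implementation="sdpa",  # "eager" and "flash_attention_2" are the other choices
)
# Inspect which attention class was instantiated for the first decoder layer.
print(type(model.model.layers[0].self_attn).__name__)  # expected: LlamaSdpaAttention
```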
+ +import os + +import torch + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "lora", + "lora_target": "all", + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + + +def test_checkpointing_enable(): + model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": False, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): + assert getattr(module, "gradient_checkpointing") is True + + +def test_checkpointing_disable(): + model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": True, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): + assert getattr(module, "gradient_checkpointing") is False + + +def test_upcast_layernorm(): + model_args, _, _, finetuning_args, _ = get_train_args({"upcast_layernorm": True, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + for name, param in model.named_parameters(): + if param.ndim == 1 and "norm" in name: + assert param.dtype == torch.float32 + + +def test_upcast_lmhead_output(): + model_args, _, _, finetuning_args, _ = get_train_args({"upcast_lmhead_output": True, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device()) + outputs: "torch.Tensor" = model.lm_head(inputs) + assert outputs.dtype == torch.float32 diff --git a/tests/model/test_base.py b/tests/model/test_base.py new file mode 100644 index 00000000..e1991b20 --- /dev/null +++ b/tests/model/test_base.py @@ -0,0 +1,79 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
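`tests/model/model_utils/test_checkpointing.py` above asserts the `gradient_checkpointing` flag on every supporting submodule. A minimal sketch of the same check done directly with transformers, outside LLaMA-Factory's loader:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3")
model.gradient_checkpointing_enable()

# Every module that exposes the flag should now report True.
flags = [m.gradient_checkpointing for m in model.modules() if hasattr(m, "gradient_checkpointing")]
print(len(flags) > 0 and all(flags))  # expected: True
```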
+ +import os +from typing import Dict + +import pytest +import torch +from transformers import AutoModelForCausalLM +from trl import AutoModelForCausalLMWithValueHead + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_infer_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead") + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "template": "llama3", + "infer_dtype": "float16", +} + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"): + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) + + +@pytest.fixture +def fix_valuehead_cpu_loading(): + def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]): + state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + AutoModelForCausalLMWithValueHead.post_init = post_init + + +def test_base(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + ref_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device() + ) + compare_model(model, ref_model) + + +@pytest.mark.usefixtures("fix_valuehead_cpu_loading") +def test_valuehead(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model( + tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True + ) + + ref_model = AutoModelForCausalLMWithValueHead.from_pretrained( + TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device() + ) + compare_model(model, ref_model) diff --git a/tests/model/test_freeze.py b/tests/model/test_freeze.py new file mode 100644 index 00000000..5f478af6 --- /dev/null +++ b/tests/model/test_freeze.py @@ -0,0 +1,85 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import torch + +from llamafactory.hparams import get_infer_args, get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "freeze", + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "finetuning_type": "freeze", + "template": "llama3", + "infer_dtype": "float16", +} + + +def test_freeze_train_all_modules(): + model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + for name, param in model.named_parameters(): + if name.startswith("model.layers.1."): + assert param.requires_grad is True + assert param.dtype == torch.float32 + else: + assert param.requires_grad is False + assert param.dtype == torch.float16 + + +def test_freeze_train_extra_modules(): + model_args, _, _, finetuning_args, _ = get_train_args( + {"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS} + ) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + for name, param in model.named_parameters(): + if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]): + assert param.requires_grad is True + assert param.dtype == torch.float32 + else: + assert param.requires_grad is False + assert param.dtype == torch.float16 + + +def test_freeze_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + for param in model.parameters(): + assert param.requires_grad is False + assert param.dtype == torch.float16 diff --git a/tests/model/test_full.py b/tests/model/test_full.py new file mode 100644 index 00000000..0a6e0743 --- /dev/null +++ b/tests/model/test_full.py @@ -0,0 +1,65 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
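`tests/model/test_freeze.py` above expects that `freeze_trainable_layers=1` leaves only the last decoder layer trainable. A hedged sketch of that `requires_grad` pattern reproduced with plain transformers (the dtype casting done by the real freeze tuner is omitted):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3", torch_dtype=torch.float32)
# Keep only the last decoder layer trainable, freeze everything else.
last_layer_prefix = "model.layers.{}.".format(model.config.num_hidden_layers - 1)
for name, param in model.named_parameters():
    param.requires_grad = name.startswith(last_layer_prefix)

print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # trainable parameter count
```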
+ +import os + +import torch + +from llamafactory.hparams import get_infer_args, get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "full", + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "finetuning_type": "full", + "template": "llama3", + "infer_dtype": "float16", +} + + +def test_full_train(): + model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + for param in model.parameters(): + assert param.requires_grad is True + assert param.dtype == torch.float32 + + +def test_full_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + for param in model.parameters(): + assert param.requires_grad is False + assert param.dtype == torch.float16 diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py new file mode 100644 index 00000000..630e5f75 --- /dev/null +++ b/tests/model/test_lora.py @@ -0,0 +1,198 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
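`tests/model/test_lora.py` below asserts that `lora_target: all` attaches adapters to every linear projection of the LLaMA block. A hedged sketch of the equivalent explicit peft configuration, with the module names taken from the set the test expects:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3")
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # only lora_A / lora_B weights are trainable
```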
+ +import os +from typing import Dict, Sequence + +import pytest +import torch +from peft import LoraModel, PeftModel +from transformers import AutoModelForCausalLM +from trl import AutoModelForCausalLMWithValueHead + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_infer_args, get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora") + +TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "lora", + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA, + "adapter_name_or_path": TINY_LLAMA_ADAPTER, + "finetuning_type": "lora", + "template": "llama3", + "infer_dtype": "float16", +} + + +def load_reference_model(is_trainable: bool = False) -> "LoraModel": + model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device() + ) + lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable) + for param in filter(lambda p: p.requires_grad, lora_model.parameters()): + param.data = param.data.to(torch.float32) + + return lora_model + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []): + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + if any(key in name for key in diff_keys): + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False + else: + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True + + +@pytest.fixture +def fix_valuehead_cpu_loading(): + def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]): + state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + AutoModelForCausalLMWithValueHead.post_init = post_init + + +def test_lora_train_qv_modules(): + model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + linear_modules = set() + for name, param in model.named_parameters(): + if any(module in name for module in ["lora_A", "lora_B"]): + linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1]) + assert param.requires_grad is True + assert param.dtype == torch.float32 + else: + assert param.requires_grad is False + assert param.dtype == torch.float16 + + assert linear_modules == {"q_proj", "v_proj"} + + +def test_lora_train_all_modules(): + model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS}) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, 
is_trainable=True) + + linear_modules = set() + for name, param in model.named_parameters(): + if any(module in name for module in ["lora_A", "lora_B"]): + linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1]) + assert param.requires_grad is True + assert param.dtype == torch.float32 + else: + assert param.requires_grad is False + assert param.dtype == torch.float16 + + assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"} + + +def test_lora_train_extra_modules(): + model_args, _, _, finetuning_args, _ = get_train_args( + {"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS} + ) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + extra_modules = set() + for name, param in model.named_parameters(): + if any(module in name for module in ["lora_A", "lora_B"]): + assert param.requires_grad is True + assert param.dtype == torch.float32 + elif "modules_to_save" in name: + extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1]) + assert param.requires_grad is True + assert param.dtype == torch.float32 + else: + assert param.requires_grad is False + assert param.dtype == torch.float16 + + assert extra_modules == {"embed_tokens", "lm_head"} + + +def test_lora_train_old_adapters(): + model_args, _, _, finetuning_args, _ = get_train_args( + {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": False, **TRAIN_ARGS} + ) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + ref_model = load_reference_model(is_trainable=True) + compare_model(model, ref_model) + + +def test_lora_train_new_adapters(): + model_args, _, _, finetuning_args, _ = get_train_args( + {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": True, **TRAIN_ARGS} + ) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + ref_model = load_reference_model(is_trainable=True) + compare_model( + model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"] + ) + + +@pytest.mark.usefixtures("fix_valuehead_cpu_loading") +def test_lora_train_valuehead(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model( + tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True, add_valuehead=True + ) + + ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( + TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device() + ) + state_dict = model.state_dict() + ref_state_dict = ref_model.state_dict() + + assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"]) + assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"]) + + +def test_lora_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + ref_model = load_reference_model().merge_and_unload() + compare_model(model, ref_model) diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py new file mode 100644 index 
00000000..030310d0 --- /dev/null +++ b/tests/model/test_pissa.py @@ -0,0 +1,90 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from peft import LoraModel, PeftModel +from transformers import AutoModelForCausalLM + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_infer_args, get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TINY_LLAMA_PISSA = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "lora", + "pissa_init": True, + "pissa_iter": -1, + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA_PISSA, + "adapter_name_or_path": TINY_LLAMA_PISSA, + "adapter_folder": "pissa_init", + "finetuning_type": "lora", + "template": "llama3", + "infer_dtype": "float16", +} + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"): + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) + + +def test_pissa_init(): + model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + base_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device() + ) + ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init", is_trainable=True) + for param in filter(lambda p: p.requires_grad, ref_model.parameters()): + param.data = param.data.to(torch.float32) + + compare_model(model, ref_model) + + +def test_pissa_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + base_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device() + ) + ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init") + ref_model = ref_model.merge_and_unload() + compare_model(model, ref_model) diff --git a/tests/test_throughput.py b/tests/test_throughput.py deleted file mode 100644 index e8048910..00000000 --- a/tests/test_throughput.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import time - -from openai 
import OpenAI -from transformers.utils.versions import require_version - - -require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") - - -def main(): - client = OpenAI( - api_key="0", - base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)), - ) - messages = [{"role": "user", "content": "Write a long essay about environment protection as long as possible."}] - num_tokens = 0 - start_time = time.time() - for _ in range(8): - result = client.chat.completions.create(messages=messages, model="test") - num_tokens += result.usage.completion_tokens - - elapsed_time = time.time() - start_time - print("Throughput: {:.2f} tokens/s".format(num_tokens / elapsed_time)) - # --infer_backend hf: 27.22 tokens/s (1.0x) - # --infer_backend vllm: 73.03 tokens/s (2.7x) - - -if __name__ == "__main__": - main()
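The deleted `tests/test_throughput.py` was a manual benchmark rather than a pytest case. If the measurement is still needed, an equivalent standalone sketch against the project's OpenAI-compatible API server (e.g. one started with `llamafactory-cli api`) is:

```python
import os
import time

from openai import OpenAI  # requires openai>=1.5.0 and a running API server

client = OpenAI(api_key="0", base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)))
messages = [{"role": "user", "content": "Write a long essay about environment protection as long as possible."}]
num_tokens, start_time = 0, time.time()
for _ in range(8):
    result = client.chat.completions.create(messages=messages, model="test")
    num_tokens += result.usage.completion_tokens

print("Throughput: {:.2f} tokens/s".format(num_tokens / (time.time() - start_time)))
```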