Merge remote-tracking branch 'upstream/main'

Former-commit-id: ea1f3ba5e030504e07053484f50f4cbdb37808bc
Author: Jonery · 2024-06-17 18:44:51 +08:00
Commit: 5d59f6562a
184 changed files with 5411 additions and 1780 deletions


@@ -4,6 +4,8 @@
 .venv
 cache
 data
+hf_cache
+output
 examples
 .dockerignore
 .gitattributes

@@ -13,6 +13,18 @@ body:
       - label: I have read the README and searched the existing issues.
         required: true

+  - type: textarea
+    id: system-info
+    validations:
+      required: true
+    attributes:
+      label: System Info
+      description: |
+        Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
+        请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
+      placeholder: llamafactory version, platform, python version, ...
+
   - type: textarea
     id: reproduction
     validations:
@@ -26,7 +38,9 @@ body:
         请合理使用 Markdown 标签来格式化您的文本。
       placeholder: |
-        python src/train_bash.py ...
+        ```bash
+        llamafactory-cli train ...
+        ```

   - type: textarea
     id: expected-behavior
@@ -38,18 +52,6 @@ body:
         Please provide a clear and concise description of what you would expect to happen.
         请提供您原本的目的,即这段代码的期望行为。

-  - type: textarea
-    id: system-info
-    validations:
-      required: false
-    attributes:
-      label: System Info
-      description: |
-        Please share your system info with us. You can run the command **transformers-cli env** and copy-paste its output below.
-        请提供您的系统信息。您可以在命令行运行 **transformers-cli env** 并将其输出复制到该文本框中。
-      placeholder: transformers version, platform, python version, ...
-
   - type: textarea
     id: others
     validations:

@@ -5,3 +5,4 @@ Fixes # (issue)
 ## Before submitting

 - [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
+- [ ] Did you write any new necessary tests?
.github/workflows/label_issue.yml (new file)

@@ -0,0 +1,17 @@
+name: label_issue
+
+on:
+  issues:
+    types:
+      - opened
+
+jobs:
+  label_issue:
+    runs-on: ubuntu-latest
+    steps:
+      - env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          ISSUE_URL: ${{ github.event.issue.html_url }}
+        run: |
+          gh issue edit $ISSUE_URL --add-label "pending"

@@ -2,28 +2,44 @@ name: tests
 on:
   push:
-    branches: [ "main" ]
+    branches:
+      - main
+    paths:
+      - "**.py"
+      - "requirements.txt"
+      - ".github/workflows/*.yml"
   pull_request:
-    branches: [ "main" ]
+    branches:
+      - main
+    paths:
+      - "**.py"
+      - "requirements.txt"
+      - ".github/workflows/*.yml"

 jobs:
-  check_code_quality:
+  tests:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - name: Checkout
+        uses: actions/checkout@v4
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: "3.8"
+          cache: "pip"
+          cache-dependency-path: "setup.py"
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install ruff
+          python -m pip install .[torch,dev]
       - name: Check quality
         run: |
           make style && make quality
+      - name: Test with pytest
+        run: |
+          make test

@@ -1,14 +1,44 @@
-FROM nvcr.io/nvidia/pytorch:24.01-py3
+# Use the NVIDIA official image with PyTorch 2.3.0
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
+FROM nvcr.io/nvidia/pytorch:24.02-py3
+
+# Define installation arguments
+ARG INSTALL_BNB=false
+ARG INSTALL_VLLM=false
+ARG INSTALL_DEEPSPEED=false
+ARG PIP_INDEX=https://pypi.org/simple

+# Set the working directory
 WORKDIR /app

+# Install the requirements
 COPY requirements.txt /app/
-RUN pip install -r requirements.txt
+RUN pip config set global.index-url $PIP_INDEX
+RUN python -m pip install --upgrade pip
+RUN python -m pip install -r requirements.txt

+# Copy the rest of the application into the image
 COPY . /app/
-RUN pip install -e .[metrics,bitsandbytes,qwen]

+# Install the LLaMA Factory
+RUN EXTRA_PACKAGES="metrics"; \
+    if [ "$INSTALL_BNB" = "true" ]; then \
+        EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
+    fi; \
+    if [ "$INSTALL_VLLM" = "true" ]; then \
+        EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
+    fi; \
+    if [ "$INSTALL_DEEPSPEED" = "true" ]; then \
+        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
+    fi; \
+    pip install -e .[$EXTRA_PACKAGES] && \
+    pip uninstall -y transformer-engine flash-attn
+
+# Set up volumes
 VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]
+
+# Expose port 7860 for the LLaMA Board
 EXPOSE 7860
-
-CMD [ "llamafactory-cli", "webui" ]
+# Expose port 8000 for the API service
+EXPOSE 8000
MANIFEST.in (new file)

@@ -0,0 +1 @@
+include LICENSE requirements.txt

@@ -1,4 +1,4 @@
-.PHONY: quality style
+.PHONY: quality style test

 check_dirs := scripts src tests

@@ -9,3 +9,6 @@ quality:
 style:
 	ruff check $(check_dirs) --fix
 	ruff format $(check_dirs)
+
+test:
+	CUDA_VISIBLE_DEVICES= pytest tests/
README.md

@@ -8,9 +8,10 @@
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
+[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
 [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
 [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)

@@ -25,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/9840a653-7e9c-41c8-ae89
 Choose your path:

 - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
+- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 - **Local machine**: Please refer to [usage](#getting-started)

 ## Table of Contents

@@ -45,9 +47,9 @@ Choose your path:
 ## Features

 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
-- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO and ORPO.
+- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.

@@ -69,14 +71,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

+[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
+
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
+
 [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.

+<details><summary>Full Changelog</summary>
+
 [24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.

 [24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.

-<details><summary>Full Changelog</summary>
-
 [24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.

 [24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
@@ -145,38 +151,38 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Supported Models

-| Model | Model size | Default module | Template |
-| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | query_key_value | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
+| Model | Model size | Template |
+| -------------------------------------------------------- | -------------------------------- | --------- |
+| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
+| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
+| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
-> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules for better convergence.
->
-> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
+> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
 >
 > Remember to use the **SAME** template in training and inference.
@@ -208,6 +214,8 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
 - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
 - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
+- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
 - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
 - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)

@@ -251,6 +259,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
 - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
 - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)

@@ -267,6 +276,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 <details><summary>Preference datasets</summary>

 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
+- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -286,21 +296,21 @@ huggingface-cli login
 | Mandatory | Minimum | Recommend |
 | ------------ | ------- | --------- |
-| python | 3.8 | 3.10 |
-| torch | 1.13.1 | 2.2.0 |
-| transformers | 4.37.2 | 4.41.0 |
-| datasets | 2.14.3 | 2.19.1 |
-| accelerate | 0.27.2 | 0.30.1 |
-| peft | 0.9.0 | 0.11.1 |
-| trl | 0.8.2 | 0.8.6 |
+| python | 3.8 | 3.11 |
+| torch | 1.13.1 | 2.3.0 |
+| transformers | 4.41.2 | 4.41.2 |
+| datasets | 2.16.0 | 2.19.2 |
+| accelerate | 0.30.1 | 0.30.1 |
+| peft | 0.11.1 | 0.11.1 |
+| trl | 0.8.6 | 0.9.4 |

 | Optional | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
 | deepspeed | 0.10.0 | 0.14.0 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
-| vllm | 0.4.0 | 0.4.2 |
-| flash-attn | 2.3.0 | 2.5.8 |
+| vllm | 0.4.3 | 0.4.3 |
+| flash-attn | 2.3.0 | 2.5.9 |

 ### Hardware Requirement
@@ -326,10 +336,10 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e .[torch,metrics]
+pip install -e ".[torch,metrics]"
 ```

-Extra dependencies available: torch, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
+Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality

 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -350,14 +360,28 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
 Join [NPU user group](assets/wechat_npu.jpg).

-To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**.
+To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
+
+```bash
+# replace the url according to your CANN version and devices
+# install CANN Toolkit
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
+bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
+
+# install CANN Kernels
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
+bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
+
+# set env variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```

-| Requirement | Minimum | Recommend |
-| ------------ | ------- | --------- |
-| CANN | 8.0.RC1 | 8.0.RC1 |
-| torch | 2.2.0 | 2.2.0 |
-| torch-npu | 2.2.0 | 2.2.0 |
-| deepspeed | 0.13.2 | 0.13.2 |
+| Requirement | Minimum | Recommend |
+| ------------ | ------- | ----------- |
+| CANN | 8.0.RC1 | 8.0.RC1 |
+| torch | 2.1.0 | 2.1.0 |
+| torch-npu | 2.1.0 | 2.1.0.post3 |
+| deepspeed | 0.13.2 | 0.13.2 |

 Docker image:
@@ -382,9 +406,9 @@ Please refer to [data/README.md](data/README.md) for checking the details about
 Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 ```

 See [examples/README.md](examples/README.md) for advanced usage (including distributed training).
@@ -394,36 +418,38 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
 ### Fine-Tuning with LLaMA Board GUI (powered by [Gradio](https://github.com/gradio-app/gradio))

-> [!IMPORTANT]
-> LLaMA Board GUI only supports training on a single GPU.
-
-#### Use local environment
-
 ```bash
-CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
+llamafactory-cli webui
 ```

-</details>
+### Build Docker

 #### Use Docker

 ```bash
-docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all \
+docker build -f ./Dockerfile \
+    --build-arg INSTALL_BNB=false \
+    --build-arg INSTALL_VLLM=false \
+    --build-arg INSTALL_DEEPSPEED=false \
+    --build-arg PIP_INDEX=https://pypi.org/simple \
+    -t llamafactory:latest .
+
+docker run -it --gpus=all \
     -v ./hf_cache:/root/.cache/huggingface/ \
     -v ./data:/app/data \
     -v ./output:/app/output \
-    -e CUDA_VISIBLE_DEVICES=0 \
     -p 7860:7860 \
+    -p 8000:8000 \
     --shm-size 16G \
-    --name llama_factory \
-    -d llama-factory:latest
+    --name llamafactory \
+    llamafactory:latest
 ```

 #### Use Docker Compose

 ```bash
-docker compose -f ./docker-compose.yml up -d
+docker-compose up -d
+docker-compose exec llamafactory bash
 ```
 <details><summary>Details about volume</summary>

@@ -437,9 +463,12 @@ docker compose -f ./docker-compose.yml up -d
 ### Deploy with OpenAI-style API and vLLM

 ```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```
+
+> [!TIP]
+> Visit https://platform.openai.com/docs/api-reference/chat/create for API document.
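
Once such an API server is running, it serves the standard OpenAI chat-completions protocol. The request below is only an illustrative sketch: the host, port, model name and prompt are assumptions, not values taken from this diff.

```bash
# Hypothetical request against a local LLaMA-Factory API server started with API_PORT=8000.
# Adjust the host, port and model name to match your own deployment.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello, who are you?"}]
  }'
```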
 ### Download from ModelScope Hub

 If you have trouble with downloading models and datasets from Hugging Face, you can use ModelScope.

@@ -448,7 +477,18 @@ If you have trouble with downloading models and datasets from Hugging Face, you
 export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
 ```

-Train the model by specifying a model ID of the ModelScope Hub as the `--model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
+Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
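
As a minimal sketch of such a run (assuming you have edited `model_name_or_path` in the example config to a ModelScope model ID; the exact config file is only an illustration):

```bash
# Sketch only: train from a ModelScope model ID instead of a Hugging Face one.
# Assumes examples/train_lora/llama3_lora_sft.yaml has been edited so that
# model_name_or_path points to LLM-Research/Meta-Llama-3-8B-Instruct.
export USE_MODELSCOPE_HUB=1   # `set USE_MODELSCOPE_HUB=1` on Windows
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```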
+### Use W&B Logger
+
+To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
+
+```yaml
+report_to: wandb
+run_name: test_run # optional
+```
+
+Set `WANDB_API_KEY` to [your key](https://wandb.ai/authorize) when launching training tasks to log in with your W&B account.
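
For example, a launch could look like the following sketch (the key value and the config file are placeholders):

```bash
# Hypothetical launch: expose the W&B key so a run configured with
# `report_to: wandb` can authenticate, then start training as usual.
export WANDB_API_KEY=xxxxxxxxxxxxxxxx   # your key from https://wandb.ai/authorize
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```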
 ## Projects using LLaMA Factory

@@ -507,7 +547,7 @@ If you have a project that should be incorporated, please contact via email or c
 This repository is licensed under the [Apache-2.0 License](LICENSE).

-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation
README_zh.md

@@ -8,9 +8,10 @@
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
+[![Open in DSW](https://gallery.pai-ml.com/assets/open-in-dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
 [![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
 [![Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
 [![GitHub Tread](https://trendshift.io/api/badge/repositories/4535)](https://trendshift.io/repositories/4535)

@@ -25,6 +26,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 选择你的打开方式:

 - **Colab**https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
+- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 - **本地机器**:请见[如何使用](#如何使用)

 ## 目录

@@ -45,9 +47,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 项目特色

 - **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
-- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练和 ORPO 训练。
+- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 - **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
-- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
+- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
 - **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。

@@ -69,14 +71,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 更新日志

+[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
+
+[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
+
 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。

+<details><summary>展开日志</summary>
+
 [24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。

 [24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。

-<details><summary>展开日志</summary>
-
 [24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。

 [24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
@@ -145,40 +151,40 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## 模型

-| 模型名 | 模型大小 | 默认模块 | Template |
-| -------------------------------------------------------- | -------------------------------- | ----------------- | --------- |
-| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | W_pack | baichuan2 |
-| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | query_key_value | chatglm3 |
-| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | q_proj,v_proj | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | q_proj,v_proj | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | query_key_value | falcon |
-| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
-| [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
-| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
-| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
-| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | q_proj,v_proj | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | q_proj,v_proj | vicuna |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | q_proj,v_proj | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | q_proj,v_proj | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | q_proj,v_proj | gemma |
-| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | qkv_proj | phi |
-| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
-| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | q_proj,v_proj | qwen |
-| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | q_proj,v_proj | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
-| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | q_proj,v_proj | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi_vl |
-| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
+| 模型名 | 模型大小 | Template |
+| -------------------------------------------------------- | -------------------------------- | --------- |
+| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
+| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
+| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
+| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
+| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
-> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块以取得更好的效果。
->
-> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Instruct/Chat模型请务必使用**对应的模板**。
->
-> 请务必在训练和推理时使用**完全一致**的模板。
+> 对于所有“基座”Base模型`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Instruct/Chat模型请务必使用**对应的模板**。
+>
+> 请务必在训练和推理时采用**完全一致**的模板。

 项目所支持模型的完整列表请参阅 [constants.py](src/llamafactory/extras/constants.py)。
@@ -208,6 +214,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
 - [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
 - [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
+- [FineWeb (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb)
+- [FineWeb-Edu (en)](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
 - [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
 - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)

@@ -251,6 +259,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
 - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
 - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)

@@ -267,6 +276,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 <details><summary>偏好数据集</summary>

 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
+- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -286,21 +296,21 @@ huggingface-cli login
 | 必需项 | 至少 | 推荐 |
 | ------------ | ------- | --------- |
-| python | 3.8 | 3.10 |
-| torch | 1.13.1 | 2.2.0 |
-| transformers | 4.37.2 | 4.41.0 |
-| datasets | 2.14.3 | 2.19.1 |
-| accelerate | 0.27.2 | 0.30.1 |
-| peft | 0.9.0 | 0.11.1 |
-| trl | 0.8.2 | 0.8.6 |
+| python | 3.8 | 3.11 |
+| torch | 1.13.1 | 2.3.0 |
+| transformers | 4.41.2 | 4.41.2 |
+| datasets | 2.16.0 | 2.19.2 |
+| accelerate | 0.30.1 | 0.30.1 |
+| peft | 0.11.1 | 0.11.1 |
+| trl | 0.8.6 | 0.9.4 |

 | 可选项 | 至少 | 推荐 |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
 | deepspeed | 0.10.0 | 0.14.0 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
-| vllm | 0.4.0 | 0.4.2 |
-| flash-attn | 2.3.0 | 2.5.8 |
+| vllm | 0.4.3 | 0.4.3 |
+| flash-attn | 2.3.0 | 2.5.9 |

 ### 硬件依赖
@@ -326,10 +336,10 @@ huggingface-cli login
 ```bash
 git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
 cd LLaMA-Factory
-pip install -e .[torch,metrics]
+pip install -e ".[torch,metrics]"
 ```

-可选的额外依赖项torch、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
+可选的额外依赖项torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality

 > [!TIP]
 > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -350,21 +360,35 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 加入 [NPU 用户群](assets/wechat_npu.jpg)。

-如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。
+在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
+
+```bash
+# 请替换 URL 为 CANN 版本和设备型号对应的 URL
+# 安装 CANN Toolkit
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run
+bash Ascend-cann-toolkit_8.0.RC1.alpha001_linux-"$(uname -i)".run --install
+
+# 安装 CANN Kernels
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Milan-ASL/Milan-ASL%20V100R001C17SPC701/Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run
+bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
+
+# 设置环境变量
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```

-| 依赖项 | 至少 | 推荐 |
-| ------------ | ------- | --------- |
-| CANN | 8.0.RC1 | 8.0.RC1 |
-| torch | 2.2.0 | 2.2.0 |
-| torch-npu | 2.2.0 | 2.2.0 |
-| deepspeed | 0.13.2 | 0.13.2 |
+| 依赖项 | 至少 | 推荐 |
+| ------------ | ------- | ----------- |
+| CANN | 8.0.RC1 | 8.0.RC1 |
+| torch | 2.1.0 | 2.1.0 |
+| torch-npu | 2.1.0 | 2.1.0.post3 |
+| deepspeed | 0.13.2 | 0.13.2 |

 Docker 镜像:

 - 32GB[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
 - 64GB[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

-请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。
+请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。

 如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。
@@ -382,9 +406,9 @@ Docker 镜像:
 下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 ```

 高级用法请参考 [examples/README_zh.md](examples/README_zh.md)(包括多 GPU 微调)。
@@ -394,34 +418,38 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_s
 ### LLaMA Board 可视化微调(由 [Gradio](https://github.com/gradio-app/gradio) 驱动)

-> [!IMPORTANT]
-> LLaMA Board 可视化界面目前仅支持单 GPU 训练。
-
-#### 使用本地环境
-
 ```bash
-CUDA_VISIBLE_DEVICES=0 GRADIO_SHARE=1 llamafactory-cli webui
+llamafactory-cli webui
 ```

+### 构建 Docker
+
 #### 使用 Docker

 ```bash
-docker build -f ./Dockerfile -t llama-factory:latest .
-docker run --gpus=all \
+docker build -f ./Dockerfile \
+    --build-arg INSTALL_BNB=false \
+    --build-arg INSTALL_VLLM=false \
+    --build-arg INSTALL_DEEPSPEED=false \
+    --build-arg PIP_INDEX=https://pypi.org/simple \
+    -t llamafactory:latest .
+
+docker run -it --gpus=all \
     -v ./hf_cache:/root/.cache/huggingface/ \
     -v ./data:/app/data \
     -v ./output:/app/output \
-    -e CUDA_VISIBLE_DEVICES=0 \
     -p 7860:7860 \
+    -p 8000:8000 \
     --shm-size 16G \
-    --name llama_factory \
-    -d llama-factory:latest
+    --name llamafactory \
+    llamafactory:latest
 ```

 #### 使用 Docker Compose

 ```bash
-docker compose -f ./docker-compose.yml up -d
+docker-compose up -d
+docker-compose exec llamafactory bash
 ```
 <details><summary>数据卷详情</summary>

@@ -435,9 +463,12 @@ docker compose -f ./docker-compose.yml up -d
 ### 利用 vLLM 部署 OpenAI API

 ```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```
+
+> [!TIP]
+> API 文档请查阅 https://platform.openai.com/docs/api-reference/chat/create。
 ### 从魔搭社区下载

 如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。

@@ -446,7 +477,18 @@ CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/l
 export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 ```

-将 `--model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
+将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
+### 使用 W&B 面板
+
+若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请添加下面的参数。
+
+```yaml
+report_to: wandb
+run_name: test_run # 可选
+```
+
+在启动训练任务时,将 `WANDB_API_KEY` 设置为[密钥](https://wandb.ai/authorize)来登录 W&B 账户。
 ## 使用了 LLaMA Factory 的项目

@@ -505,7 +547,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。

-使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## 引用
Two binary image files changed (contents not shown): 192 KiB → 140 KiB and 146 KiB → 148 KiB.

@@ -12,6 +12,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
   "ranking": "whether the dataset is a preference dataset or not. (default: False)",
   "subset": "the name of the subset. (optional, default: None)",
   "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
+  "num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
   "columns (optional)": {
     "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
     "query": "the column name in the dataset containing the queries. (default: input)",

@@ -12,6 +12,7 @@
   "ranking": "是否为偏好数据集可选默认False",
   "subset": "数据集子集的名称可选默认None",
   "folder": "Hugging Face 仓库的文件夹名称可选默认None",
+  "num_samples": "该数据集中用于训练的样本数量。可选默认None",
   "columns可选": {
     "prompt": "数据集代表提示词的表头名称默认instruction",
     "query": "数据集代表请求的表头名称默认input",

@@ -248,6 +248,10 @@
   "ruozhiba_gpt4": {
     "hf_hub_url": "hfl/ruozhiba_gpt4_turbo"
   },
+  "neo_sft": {
+    "hf_hub_url": "m-a-p/neo_sft_phase2",
+    "formatting": "sharegpt"
+  },
   "llava_1k_en": {
     "hf_hub_url": "BUAADreamer/llava-en-zh-2k",
     "subset": "en",
@@ -308,6 +312,20 @@
       "assistant_tag": "assistant"
     }
   },
+  "mllm_pt_demo": {
+    "hf_hub_url": "BUAADreamer/mllm_pt_demo",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages",
+      "images": "images"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "user",
+      "assistant_tag": "assistant"
+    }
+  },
   "oasst_de": {
     "hf_hub_url": "mayflowergmbh/oasst_de"
   },
@@ -377,6 +395,16 @@
       "rejected": "rejected"
     }
   },
+  "ultrafeedback": {
+    "hf_hub_url": "llamafactory/ultrafeedback_binarized",
+    "ms_hub_url": "llamafactory/ultrafeedback_binarized",
+    "ranking": true,
+    "columns": {
+      "prompt": "instruction",
+      "chosen": "chosen",
+      "rejected": "rejected"
+    }
+  },
   "orca_pairs": {
     "hf_hub_url": "Intel/orca_dpo_pairs",
     "ranking": true,
@@ -434,6 +462,15 @@
       "assistant_tag": "assistant"
     }
   },
+  "ultrafeedback_kto": {
+    "hf_hub_url": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
+    "ms_hub_url": "AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto",
+    "columns": {
+      "prompt": "prompt",
+      "response": "completion",
+      "kto_tag": "label"
+    }
+  },
   "wiki_demo": {
     "file_name": "wiki_demo.txt",
     "columns": {
@@ -487,6 +524,18 @@
       "prompt": "text"
     }
   },
+  "fileweb": {
+    "hf_hub_url": "HuggingFaceFW/fineweb",
+    "columns": {
+      "prompt": "text"
+    }
+  },
+  "fileweb_edu": {
+    "hf_hub_url": "HuggingFaceFW/fineweb-edu",
+    "columns": {
+      "prompt": "text"
+    }
+  },
   "the_stack": {
     "hf_hub_url": "bigcode/the-stack",
     "ms_hub_url": "AI-ModelScope/the-stack",

View File

@ -1,20 +1,25 @@
version: '3.8'
services: services:
llama-factory: llamafactory:
build: build:
dockerfile: Dockerfile dockerfile: Dockerfile
context: . context: .
container_name: llama_factory args:
INSTALL_BNB: false
INSTALL_VLLM: false
INSTALL_DEEPSPEED: false
PIP_INDEX: https://pypi.org/simple
container_name: llamafactory
volumes: volumes:
- ./hf_cache:/root/.cache/huggingface/ - ./hf_cache:/root/.cache/huggingface/
- ./data:/app/data - ./data:/app/data
- ./output:/app/output - ./output:/app/output
environment:
- CUDA_VISIBLE_DEVICES=0
ports: ports:
- "7860:7860" - "7860:7860"
- "8000:8000"
ipc: host ipc: host
tty: true
stdin_open: true
command: bash
deploy: deploy:
resources: resources:
reservations: reservations:
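Assuming the compose file sits at the repository root, a typical workflow with the renamed service is to build, start, and attach a shell (the container now simply idles on `bash`):

```bash
# Build and start the llamafactory service, then open a shell inside it.
docker compose up -d --build
docker compose exec llamafactory bash
# Inside the container, run e.g.:  llamafactory-cli webui
```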

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import datasets import datasets

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import datasets import datasets

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import datasets import datasets
@ -154,7 +155,7 @@ class MMLU(datasets.GeneratorBasedBuilder):
] ]
def _generate_examples(self, filepath): def _generate_examples(self, filepath):
df = pd.read_csv(filepath) df = pd.read_csv(filepath, header=None)
df.columns = ["question", "A", "B", "C", "D", "answer"] df.columns = ["question", "A", "B", "C", "D", "answer"]
for i, instance in enumerate(df.to_dict(orient="records")): for i, instance in enumerate(df.to_dict(orient="records")):
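The added `header=None` matters because the benchmark CSV files ship without a header row; without it, pandas would silently treat the first question as column names. A quick sanity check (file path hypothetical):

```bash
# Compare row counts with and without header=None on an MMLU-style CSV.
python -c "
import pandas as pd
path = 'abstract_algebra_test.csv'  # hypothetical local copy
print(len(pd.read_csv(path)), len(pd.read_csv(path, header=None)))
"
```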

View File

@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.
## Table of Contents ## Table of Contents
- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu) - [LoRA Fine-Tuning](#lora-fine-tuning)
- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu) - [QLoRA Fine-Tuning](#qlora-fine-tuning)
- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus) - [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
- [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization) - [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
- [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models) - [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
- [Extras](#extras) - [Extras](#extras)
Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
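For instance, restricting a run to the first two devices only requires prefixing the command:

```bash
# GPU example
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
# Ascend NPU example
ASCEND_RT_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```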
## Examples ## Examples
### LoRA Fine-Tuning on A Single GPU ### LoRA Fine-Tuning
#### (Continuous) Pre-Training #### (Continuous) Pre-Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
``` ```
#### Supervised Fine-Tuning #### Supervised Fine-Tuning
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
``` ```
#### Multimodal Supervised Fine-Tuning #### Multimodal Supervised Fine-Tuning
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
``` ```
#### Reward Modeling #### Reward Modeling
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
``` ```
#### PPO Training #### PPO Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
``` ```
#### DPO/ORPO/SimPO Training #### DPO/ORPO/SimPO Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
``` ```
#### KTO Training #### KTO Training
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
``` ```
#### Preprocess Dataset #### Preprocess Dataset
@ -64,93 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset. It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
``` ```
#### Evaluating on MMLU/CMMLU/C-Eval Benchmarks #### Evaluating on MMLU/CMMLU/C-Eval Benchmarks
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
``` ```
#### Batch Predicting and Computing BLEU and ROUGE Scores #### Batch Predicting and Computing BLEU and ROUGE Scores
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
``` ```
### QLoRA Fine-Tuning on a Single GPU #### Supervised Fine-Tuning on Multiple Nodes
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
``` FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
```
#### Supervised Fine-Tuning with 4-bit AWQ Quantization
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
```
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
```
### LoRA Fine-Tuning on Multiple GPUs
#### Supervised Fine-Tuning with Accelerate on Single Node
```bash
bash examples/lora_multi_gpu/single_node.sh
```
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes
```bash
bash examples/lora_multi_gpu/multi_node.sh
``` ```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding) #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
```bash ```bash
bash examples/lora_multi_gpu/ds_zero3.sh FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
``` ```
### LoRA Fine-Tuning on Multiple NPUs ### QLoRA Fine-Tuning
#### Supervised Fine-Tuning with DeepSpeed ZeRO-0 #### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
```bash ```bash
bash examples/lora_multi_npu/ds_zero0.sh llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
``` ```
### Full-Parameter Fine-Tuning on Multiple GPUs #### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
#### Supervised Fine-Tuning with Accelerate on Single Node
```bash ```bash
bash examples/full_multi_gpu/single_node.sh llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
``` ```
#### Supervised Fine-Tuning with Accelerate on Multiple Nodes #### Supervised Fine-Tuning with 4-bit AWQ Quantization
```bash ```bash
bash examples/full_multi_gpu/multi_node.sh llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
```
#### Supervised Fine-Tuning with 2-bit AQLM Quantization
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
```
### Full-Parameter Fine-Tuning
#### Supervised Fine-Tuning on Single Node
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### Supervised Fine-Tuning on Multiple Nodes
```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
``` ```
#### Batch Predicting and Computing BLEU and ROUGE Scores #### Batch Predicting and Computing BLEU and ROUGE Scores
```bash ```bash
bash examples/full_multi_gpu/predict.sh llamafactory-cli train examples/train_full/llama3_full_predict.yaml
``` ```
### Merging LoRA Adapters and Quantization ### Merging LoRA Adapters and Quantization
@ -160,35 +146,33 @@ bash examples/full_multi_gpu/predict.sh
Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters. Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
``` ```
#### Quantizing Model using AutoGPTQ #### Quantizing Model using AutoGPTQ
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
``` ```
### Inferring LoRA Fine-Tuned Models ### Inferring LoRA Fine-Tuned Models
Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
#### Use CLI #### Use CLI
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
``` ```
#### Use Web UI #### Use Web UI
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
``` ```
#### Launch OpenAI-style API #### Launch OpenAI-style API
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml llamafactory-cli api examples/inference/llama3_lora_sft.yaml
``` ```
### Extras ### Extras
@ -196,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
#### Full-Parameter Fine-Tuning using GaLore #### Full-Parameter Fine-Tuning using GaLore
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
``` ```
#### Full-Parameter Fine-Tuning using BAdam #### Full-Parameter Fine-Tuning using BAdam
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
``` ```
#### LoRA+ Fine-Tuning #### LoRA+ Fine-Tuning
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
```
#### PiSSA Fine-Tuning
```bash
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
``` ```
#### Mixture-of-Depths Fine-Tuning #### Mixture-of-Depths Fine-Tuning
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
``` ```
#### LLaMA-Pro Fine-Tuning #### LLaMA-Pro Fine-Tuning
```bash ```bash
bash examples/extras/llama_pro/expand.sh bash examples/extras/llama_pro/expand.sh
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
``` ```
#### FSDP+QLoRA Fine-Tuning #### FSDP+QLoRA Fine-Tuning
```bash ```bash
bash examples/extras/fsdp_qlora/single_node.sh bash examples/extras/fsdp_qlora/train.sh
``` ```

View File

@ -4,59 +4,59 @@
## 目录 ## 目录
- [单 GPU LoRA 微调](#单-gpu-lora-微调) - [LoRA 微调](#lora-微调)
- [单 GPU QLoRA 微调](#单-gpu-qlora-微调) - [QLoRA 微调](#qlora-微调)
- [多 GPU LoRA 微调](#多-gpu-lora-微调) - [全参数微调](#全参数微调)
- [多 NPU LoRA 微调](#多-npu-lora-微调)
- [多 GPU 全参数微调](#多-gpu-全参数微调)
- [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化) - [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
- [推理 LoRA 模型](#推理-lora-模型) - [推理 LoRA 模型](#推理-lora-模型)
- [杂项](#杂项) - [杂项](#杂项)
使用 `CUDA_VISIBLE_DEVICES`GPU`ASCEND_RT_VISIBLE_DEVICES`NPU选择计算设备。
## 示例 ## 示例
### 单 GPU LoRA 微调 ### LoRA 微调
#### (增量)预训练 #### (增量)预训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
``` ```
#### 指令监督微调 #### 指令监督微调
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
``` ```
#### 多模态指令监督微调 #### 多模态指令监督微调
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
``` ```
#### 奖励模型训练 #### 奖励模型训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
``` ```
#### PPO 训练 #### PPO 训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
``` ```
#### DPO/ORPO/SimPO 训练 #### DPO/ORPO/SimPO 训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
``` ```
#### KTO 训练 #### KTO 训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
``` ```
#### 预处理数据集 #### 预处理数据集
@ -64,93 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。 对于大数据集有帮助,在配置中使用 `tokenized_path` 以加载预处理后的数据集。
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
``` ```
#### 在 MMLU/CMMLU/C-Eval 上评估 #### 在 MMLU/CMMLU/C-Eval 上评估
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
``` ```
#### 批量预测并计算 BLEU 和 ROUGE 分数 #### 批量预测并计算 BLEU 和 ROUGE 分数
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
``` ```
### 单 GPU QLoRA 微调 #### 多机指令监督微调
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
``` FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
```
#### 基于 4 比特 AWQ 量化进行指令监督微调
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
```
#### 基于 2 比特 AQLM 量化进行指令监督微调
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
```
### 多 GPU LoRA 微调
#### 使用 Accelerate 进行单节点训练
```bash
bash examples/lora_multi_gpu/single_node.sh
```
#### 使用 Accelerate 进行多节点训练
```bash
bash examples/lora_multi_gpu/multi_node.sh
``` ```
#### 使用 DeepSpeed ZeRO-3 平均分配显存 #### 使用 DeepSpeed ZeRO-3 平均分配显存
```bash ```bash
bash examples/lora_multi_gpu/ds_zero3.sh FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
``` ```
### 多 NPU LoRA 微调 ### QLoRA 微调
#### 使用 DeepSpeed ZeRO-0 训练 #### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
```bash ```bash
bash examples/lora_multi_npu/ds_zero0.sh llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
``` ```
### 多 GPU 全参数微调 #### 基于 4/8 比特 GPTQ 量化进行指令监督微调
#### 使用 DeepSpeed 进行单节点训练
```bash ```bash
bash examples/full_multi_gpu/single_node.sh llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
``` ```
#### 使用 DeepSpeed 进行多节点训练 #### 基于 4 比特 AWQ 量化进行指令监督微调
```bash ```bash
bash examples/full_multi_gpu/multi_node.sh llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
```
#### 基于 2 比特 AQLM 量化进行指令监督微调
```bash
llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
```
### 全参数微调
#### 在单机上进行指令监督微调
```bash
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
```
#### 在多机上进行指令监督微调
```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
``` ```
#### 批量预测并计算 BLEU 和 ROUGE 分数 #### 批量预测并计算 BLEU 和 ROUGE 分数
```bash ```bash
bash examples/full_multi_gpu/predict.sh llamafactory-cli train examples/train_full/llama3_full_predict.yaml
``` ```
### 合并 LoRA 适配器与模型量化 ### 合并 LoRA 适配器与模型量化
@ -160,35 +146,33 @@ bash examples/full_multi_gpu/predict.sh
注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。 注:请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
``` ```
#### 使用 AutoGPTQ 量化模型 #### 使用 AutoGPTQ 量化模型
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
``` ```
### 推理 LoRA 模型 ### 推理 LoRA 模型
使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
#### 使用命令行接口 #### 使用命令行接口
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
``` ```
#### 使用浏览器界面 #### 使用浏览器界面
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
``` ```
#### 启动 OpenAI 风格 API #### 启动 OpenAI 风格 API
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml llamafactory-cli api examples/inference/llama3_lora_sft.yaml
``` ```
### 杂项 ### 杂项
@ -196,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y
#### 使用 GaLore 进行全参数训练 #### 使用 GaLore 进行全参数训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
``` ```
#### 使用 BAdam 进行全参数训练 #### 使用 BAdam 进行全参数训练
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
``` ```
#### LoRA+ 微调 #### LoRA+ 微调
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
```
#### PiSSA 微调
```bash
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
``` ```
#### 深度混合微调 #### 深度混合微调
```bash ```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
``` ```
#### LLaMA-Pro 微调 #### LLaMA-Pro 微调
```bash ```bash
bash examples/extras/llama_pro/expand.sh bash examples/extras/llama_pro/expand.sh
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
``` ```
#### FSDP+QLoRA 微调 #### FSDP+QLoRA 微调
```bash ```bash
bash examples/extras/fsdp_qlora/single_node.sh bash examples/extras/fsdp_qlora/train.sh
``` ```

View File

@ -5,16 +5,16 @@ downcast_bf16: 'no'
fsdp_config: fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false fsdp_forward_prefetch: false
fsdp_offload_params: true fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: true # offload may affect training speed
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true fsdp_sync_module_states: true
fsdp_use_orig_params: false fsdp_use_orig_params: true
machine_rank: 0 machine_rank: 0
main_training_function: main main_training_function: main
mixed_precision: fp16 mixed_precision: fp16 # or bf16
num_machines: 1 # the number of nodes num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static rdzv_backend: static
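A sketch of how this accelerate config is consumed, mirroring the FSDP+QLoRA launcher elsewhere in this change (GPU ids illustrative):

```bash
# Launch a 2-GPU FSDP run with the updated accelerate config.
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
  --config_file examples/accelerate/fsdp_config.yaml \
  src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml
```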

View File

@ -1,18 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -1,16 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1 # the number of nodes
num_processes: 4 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -1,18 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 1
main_process_ip: 192.168.0.1
main_process_port: 29555
main_training_function: main
mixed_precision: fp16
num_machines: 2 # the number of nodes
num_processes: 8 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -28,14 +28,14 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
pure_bf16: true pure_bf16: true
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -6,10 +6,7 @@ quantization_bit: 4
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### ddp
ddp_timeout: 180000000
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -29,14 +26,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -1,10 +1,6 @@
#!/bin/bash #!/bin/bash
# DO NOT use GPTQ/AWQ model in FSDP+QLoRA # DO NOT use GPTQ/AWQ model in FSDP+QLoRA
pip install "transformers>=4.39.1"
pip install "accelerate>=0.28.0"
pip install "bitsandbytes>=0.43.0"
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \ CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file examples/accelerate/fsdp_config.yaml \ --config_file examples/accelerate/fsdp_config.yaml \
src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml src/train.py examples/extras/fsdp_qlora/llama3_lora_sft.yaml
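Since the explicit `pip install` lines were dropped from the script, the packages now have to be present beforehand; one way, with version floors taken from the removed lines and the updated requirements, is:

```bash
# Install the pieces FSDP+QLoRA relies on, then run the renamed launcher.
pip install "transformers>=4.41.2" "accelerate>=0.30.1" "bitsandbytes>=0.43.0"
bash examples/extras/fsdp_qlora/train.sh
```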

View File

@ -29,14 +29,14 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
pure_bf16: true pure_bf16: true
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -27,14 +27,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
loraplus_lr_ratio: 16.0 loraplus_lr_ratio: 16.0
### dataset ### dataset
@ -26,14 +26,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -26,14 +26,15 @@ overwrite_output_dir: true
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
optim: paged_adamw_8bit optim: paged_adamw_8bit
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
pure_bf16: true pure_bf16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,10 +5,10 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
pissa_init: true
### ddp pissa_iter: 4
ddp_timeout: 180000000 pissa_convert: true
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -27,15 +27,16 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500
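The PiSSA options above pair with the dedicated recipe referenced in the examples; a minimal invocation is:

```bash
llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
```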

View File

@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=2
RANK=0
MASTER_ADDR=192.168.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@ -1,5 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/single_config.yaml \
src/train.py examples/full_multi_gpu/llama3_full_predict.yaml

View File

@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/full_multi_gpu/llama3_full_sft.yaml

View File

@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/lora_multi_gpu/llama3_lora_sft_ds.yaml

View File

@ -1,6 +0,0 @@
#!/bin/bash
# also launch it on slave machine using slave_config.yaml
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/master_config.yaml \
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@ -1,5 +0,0 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--config_file examples/accelerate/single_config.yaml \
src/train.py examples/lora_multi_gpu/llama3_lora_sft.yaml

View File

@ -1,15 +0,0 @@
#!/bin/bash
NPROC_PER_NODE=4
NNODES=1
RANK=0
MASTER_ADDR=127.0.0.1
MASTER_PORT=29500
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 torchrun \
--nproc_per_node $NPROC_PER_NODE \
--nnodes $NNODES \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
src/train.py examples/lora_multi_npu/llama3_lora_sft_ds.yaml

View File

@ -5,9 +5,6 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: full finetuning_type: full
### ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json deepspeed: examples/deepspeed/ds_z3_config.json
### dataset ### dataset
@ -28,14 +25,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: dpo stage: dpo
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
pref_beta: 0.1 pref_beta: 0.1
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo] pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
@ -27,14 +27,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.000005 learning_rate: 5.0e-6
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,8 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: kto stage: kto
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
pref_beta: 0.1
### dataset ### dataset
dataset: kto_en_demo dataset: kto_en_demo
@ -25,14 +26,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.000005 learning_rate: 5.0e-6
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -6,7 +6,7 @@ reward_model: saves/llama3-8b/lora/reward
stage: ppo stage: ppo
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -26,11 +26,12 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### generate ### generate
max_new_tokens: 512 max_new_tokens: 512

View File

@ -22,3 +22,4 @@ overwrite_output_dir: true
### eval ### eval
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
predict_with_generate: true predict_with_generate: true
ddp_timeout: 180000000

View File

@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: pt stage: pt
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: c4_demo dataset: c4_demo
@ -24,14 +24,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: rm stage: rm
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: dpo_en_demo dataset: dpo_en_demo
@ -25,14 +25,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.00001 learning_rate: 1.0e-5
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -25,14 +25,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,10 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z0_config.json deepspeed: examples/deepspeed/ds_z0_config.json
### dataset ### dataset
@ -29,14 +26,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,10 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json deepspeed: examples/deepspeed/ds_z3_config.json
### dataset ### dataset
@ -29,14 +26,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo

View File

@ -6,7 +6,7 @@ visual_inputs: true
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: mllm_demo dataset: mllm_demo
@ -26,14 +26,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: ISTA-DASLab/Meta-Llama-3-8B-Instruct-AQLM-2Bit-1x16
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -25,14 +25,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-AWQ
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -25,14 +25,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -6,7 +6,7 @@ quantization_bit: 4
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -26,14 +26,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -5,7 +5,7 @@ model_name_or_path: TechxGenus/Meta-Llama-3-8B-Instruct-GPTQ
stage: sft stage: sft
do_train: true do_train: true
finetuning_type: lora finetuning_type: lora
lora_target: q_proj,v_proj lora_target: all
### dataset ### dataset
dataset: identity,alpaca_en_demo dataset: identity,alpaca_en_demo
@ -25,14 +25,15 @@ overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 1 per_device_train_batch_size: 1
gradient_accumulation_steps: 8 gradient_accumulation_steps: 8
learning_rate: 0.0001 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
warmup_steps: 0.1 warmup_ratio: 0.1
fp16: true fp16: true
ddp_timeout: 180000000
### eval ### eval
val_size: 0.1 val_size: 0.1
per_device_eval_batch_size: 1 per_device_eval_batch_size: 1
evaluation_strategy: steps eval_strategy: steps
eval_steps: 500 eval_steps: 500

View File

@ -1,12 +1,13 @@
transformers>=4.37.2 transformers>=4.41.2
datasets>=2.14.3 datasets>=2.16.0
accelerate>=0.27.2 accelerate>=0.30.1
peft>=0.10.0 peft>=0.11.1
trl>=0.8.1 trl>=0.8.6
gradio>=4.0.0 gradio>=4.0.0
scipy scipy
einops einops
sentencepiece sentencepiece
tiktoken
protobuf protobuf
uvicorn uvicorn
pydantic pydantic
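As a hedged reminder, these floors are usually applied from a source checkout either directly or through the package extras defined in `setup.py`:

```bash
# Install the pinned floors directly...
pip install -r requirements.txt
# ...or via an editable install with extras (extra names assumed from setup.py).
pip install -e ".[torch,metrics]"
```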

View File

@ -1,7 +1,20 @@
# coding=utf-8 # coding=utf-8
# Calculates the flops of pre-trained models. # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512 #
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/ # This code is inspired by the Microsoft's DeepSpeed library.
# https://www.deepspeed.ai/tutorials/flops-profiler/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fire import fire
import torch import torch
@ -17,6 +30,10 @@ def calculate_flops(
seq_length: int = 256, seq_length: int = 256,
flash_attn: str = "auto", flash_attn: str = "auto",
): ):
r"""
Calculates the flops of pre-trained models.
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
"""
with get_accelerator().device(0): with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn)) chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device) fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)

View File

@ -1,7 +1,20 @@
# coding=utf-8 # coding=utf-8
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. # Copyright 2024 imoneoi and the LlamaFactory team.
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16 #
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py # This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
from typing import Literal from typing import Literal
@ -32,6 +45,10 @@ def calculate_lr(
cutoff_len: int = 1024, # i.e. maximum input length during training cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate, is_mistral: bool = False, # mistral model uses a smaller learning rate,
): ):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
"""
model_args, data_args, training_args, _, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
stage=stage, stage=stage,

View File

@ -1,6 +1,17 @@
# coding=utf-8 # coding=utf-8
# Calculates the ppl on the dataset of the pre-trained models. # Copyright 2024 the LlamaFactory team.
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from dataclasses import dataclass from dataclasses import dataclass
@ -56,6 +67,10 @@ def cal_ppl(
max_samples: Optional[int] = None, max_samples: Optional[int] = None,
train_on_prompt: bool = False, train_on_prompt: bool = False,
): ):
r"""
Calculates the ppl on the dataset of the pre-trained models.
Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args( model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict( dict(
stage=stage, stage=stage,

View File

@ -1,6 +1,17 @@
# coding=utf-8 # coding=utf-8
# Calculates the distribution of the input lengths in the dataset. # Copyright 2024 the LlamaFactory team.
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict from collections import defaultdict
@ -19,6 +30,10 @@ def length_cdf(
template: str = "default", template: str = "default",
interval: int = 1000, interval: int = 1000,
): ):
r"""
Calculates the distribution of the input lengths in the dataset.
Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
"""
model_args, data_args, training_args, _, _ = get_train_args( model_args, data_args, training_args, _, _ = get_train_args(
dict( dict(
stage="sft", stage="sft",

View File

@ -1,7 +1,20 @@
# coding=utf-8 # coding=utf-8
# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models. # Copyright 2024 Tencent Inc. and the LlamaFactory team.
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 #
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py # This code is inspired by the Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
@ -37,6 +50,10 @@ def block_expansion(
shard_size: Optional[str] = "2GB", shard_size: Optional[str] = "2GB",
save_safetensors: Optional[bool] = False, save_safetensors: Optional[bool] = False,
): ):
r"""
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
num_layers = getattr(config, "num_hidden_layers") num_layers = getattr(config, "num_hidden_layers")
setattr(config, "num_hidden_layers", num_layers + num_expand) setattr(config, "num_hidden_layers", num_layers + num_expand)
@ -103,11 +120,11 @@ def block_expansion(
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print("Model weights saved in {}".format(output_dir))
print("Fine-tune this model with:") print("- Fine-tune this model with:")
print(" --model_name_or_path {} \\".format(output_dir)) print("model_name_or_path: {}".format(output_dir))
print(" --finetuning_type freeze \\") print("finetuning_type: freeze")
print(" --freeze_trainable_layers {} \\".format(num_expand)) print("freeze_trainable_layers: {}".format(num_expand))
print(" --use_llama_pro") print("use_llama_pro: true")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,8 +1,17 @@
# coding=utf-8 # coding=utf-8
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B. # Copyright 2024 the LlamaFactory team.
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output #
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py # Licensed under the Apache License, Version 2.0 (the "License");
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
def llamafy_baichuan2( def llamafy_baichuan2(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
): ):
r"""
Converts the Baichuan2-7B model into the same format as LLaMA2-7B.
Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
"""
try: try:
os.makedirs(output_dir, exist_ok=False) os.makedirs(output_dir, exist_ok=False)
except Exception as e: except Exception as e:

View File

@ -1,7 +1,17 @@
# coding=utf-8 # coding=utf-8
# Converts the Qwen models in the same format as LLaMA2. # Copyright 2024 the LlamaFactory team.
# Usage: python llamafy_qwen.py --input_dir input --output_dir output #
# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
def llamafy_qwen( def llamafy_qwen(
input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
): ):
r"""
Converts the Qwen models into the same format as LLaMA2.
Usage: python llamafy_qwen.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
"""
try: try:
os.makedirs(output_dir, exist_ok=False) os.makedirs(output_dir, exist_ok=False)
except Exception as e: except Exception as e:

View File

@ -1,14 +1,25 @@
# coding=utf-8 # coding=utf-8
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir #
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py # This code is based on HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
import fire import fire
import torch
import torch.nn as nn
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@ -17,38 +28,21 @@ if TYPE_CHECKING:
from transformers import PreTrainedModel from transformers import PreTrainedModel
class Shell(nn.Module):
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False)
if bias is not None:
self.bias = nn.Parameter(bias, requires_grad=False)
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
parent_name = ".".join(name.split(".")[:-1])
child_name = name.split(".")[-1]
parent_module = model.get_submodule(parent_name)
child_module = getattr(parent_module, child_name)
base_layer = getattr(child_module, "base_layer")
weight = getattr(base_layer, "weight", None)
bias = getattr(base_layer, "bias", None)
setattr(parent_module, child_name, Shell(weight, bias))
print("Model unwrapped.")
def quantize_loftq( def quantize_loftq(
model_name_or_path: str, model_name_or_path: str,
save_dir: str, output_dir: str,
loftq_bits: Optional[int] = 4, loftq_bits: int = 4,
loftq_iter: Optional[int] = 1, loftq_iter: int = 4,
lora_alpha: Optional[int] = None, lora_alpha: int = None,
lora_rank: Optional[int] = 16, lora_rank: int = 16,
lora_target: Optional[str] = "q_proj,v_proj", lora_dropout: float = 0,
save_safetensors: Optional[bool] = False, lora_target: str = "q_proj,v_proj",
save_safetensors: bool = True,
): ):
r"""
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter) loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
@ -57,25 +51,34 @@ def quantize_loftq(
inference_mode=True, inference_mode=True,
r=lora_rank, r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=0.1, lora_dropout=lora_dropout,
target_modules=[name.strip() for name in lora_target.split(",")], target_modules=[name.strip() for name in lora_target.split(",")],
init_lora_weights="loftq", init_lora_weights="loftq",
loftq_config=loftq_config, loftq_config=loftq_config,
) )
# Init LoftQ model # Init LoftQ model
lora_model = get_peft_model(model, lora_config) print("Initializing LoftQ weights, this may take several minutes. Please wait patiently.")
base_model: "PreTrainedModel" = lora_model.get_base_model() peft_model = get_peft_model(model, lora_config)
loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model # Save LoftQ model
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir) setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True) setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors) peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir))
# Save base model # Save base model
unwrap_model(base_model) base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(save_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(save_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(loftq_dir))
print("finetuning_type: lora")
print("quantization_bit: {}".format(loftq_bits))
if __name__ == "__main__": if __name__ == "__main__":
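As rewritten above, the script now saves the adapter under `output_dir/loftq_init` and the dequantized base model in `output_dir` itself. A minimal sketch of invoking it directly, assuming `scripts/loftq_init.py` is importable; all paths are placeholders:

```python
from loftq_init import quantize_loftq  # assumes scripts/ is on PYTHONPATH

# Writes the base model to llama2_loftq/ and the LoftQ-initialized
# adapter to llama2_loftq/loftq_init/, as in the diff above.
quantize_loftq(
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    output_dir="llama2_loftq",
    loftq_bits=4,
    lora_rank=16,
    lora_target="q_proj,v_proj",
)
```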

scripts/pissa_init.py Normal file
View File

@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import fire
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedModel
def quantize_pissa(
model_name_or_path: str,
output_dir: str,
pissa_iter: int = 4,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
lora_target: str = "q_proj,v_proj",
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout,
target_modules=[name.strip() for name in lora_target.split(",")],
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter)
)
# Init PiSSA model
peft_model = get_peft_model(model, lora_config)
pissa_dir = os.path.join(output_dir, "pissa_init")
# Save PiSSA model
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir))
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(pissa_dir))
print("finetuning_type: lora")
print("pissa_init: false")
print("pissa_convert: true")
print("- and optionally with:")
print("quantization_bit: 4")
if __name__ == "__main__":
fire.Fire(quantize_pissa)
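The new script follows the same two-step pattern as the LoftQ helper: decompose once, then fine-tune against the saved artifacts. A sketch of the call and the follow-up options, taken from the printed hints; paths are placeholders:

```python
from pissa_init import quantize_pissa  # assumes scripts/ is on PYTHONPATH

# Saves the residual base model to llama2_pissa/ and the PiSSA-initialized
# adapter to llama2_pissa/pissa_init/.
quantize_pissa(model_name_or_path="meta-llama/Llama-2-7b-hf", output_dir="llama2_pissa")

# Fine-tuning options mirroring the printed hints above.
finetune_args = {
    "model_name_or_path": "llama2_pissa",
    "adapter_name_or_path": "llama2_pissa/pissa_init",
    "finetuning_type": "lora",
    "pissa_init": False,
    "pissa_convert": True,
    # optionally: "quantization_bit": 4,
}
```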

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
from typing import Sequence from typing import Sequence
@ -20,7 +35,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
def main(): def main():
client = OpenAI( client = OpenAI(
api_key="0", api_key="{}".format(os.environ.get("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)), base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
) )
tools = [ tools = [

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import re import re
@ -5,7 +19,7 @@ from setuptools import find_packages, setup
def get_version(): def get_version():
with open(os.path.join("src", "llamafactory", "cli.py"), "r", encoding="utf-8") as f: with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content) (version,) = re.findall(pattern, file_content)
@ -21,18 +35,19 @@ def get_requires():
extra_require = { extra_require = {
"torch": ["torch>=1.13.1"], "torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"], "metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.14.0"], "deepspeed": ["deepspeed>=0.10.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"],
"vllm": ["vllm>=0.4.0"], "vllm": ["vllm>=0.4.3"],
"galore": ["galore-torch"], "galore": ["galore-torch"],
"badam": ["badam"], "badam": ["badam"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"], "awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"], "aqlm": ["aqlm[gpu]>=1.1.0"],
"qwen": ["tiktoken", "transformers_stream_generator"], "qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"], "modelscope": ["modelscope"],
"quality": ["ruff"], "dev": ["ruff", "pytest"],
} }

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import uvicorn import uvicorn

View File

@ -1,4 +1,18 @@
# Level: api, webui > chat, eval, train > data, model > extras, hparams # Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Level: api, webui > chat, eval, train > data, model > hparams > extras
from .cli import VERSION from .cli import VERSION

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional from typing import Optional

View File

@ -1,10 +1,27 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import json import json
import os
import uuid import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole from ..data import Role as DataRole
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify from .common import dictify, jsonify
from .protocol import ( from .protocol import (
ChatCompletionMessage, ChatCompletionMessage,
@ -25,7 +42,17 @@ if is_fastapi_available():
from fastapi import HTTPException, status from fastapi import HTTPException, status
if is_pillow_available():
from PIL import Image
if is_requests_available():
import requests
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from ..chat import ChatModel from ..chat import ChatModel
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@ -40,7 +67,9 @@ ROLE_MAPPING = {
} }
def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]: def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False))) logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0: if len(request.messages) == 0:
@ -49,12 +78,13 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
if request.messages[0].role == Role.SYSTEM: if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content system = request.messages.pop(0).content
else: else:
system = "" system = None
if len(request.messages) % 2 == 0: if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = [] input_messages = []
image = None
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@ -66,6 +96,21 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
arguments = message.tool_calls[0].function.arguments arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False) content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
if input_item.type == "text":
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
if image_url.startswith("data:image"): # base64 image
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
image_path = io.BytesIO(image_data)
elif os.path.isfile(image_url): # local file
image_path = open(image_url, "rb")
else: # web uri
image_path = requests.get(image_url, stream=True).raw
image = Image.open(image_path).convert("RGB")
else: else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@ -76,9 +121,9 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, s
except Exception: except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else: else:
tools = "" tools = None
return input_messages, system, tools return input_messages, system, tools, image
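With the branch above, the API accepts OpenAI-style multimodal content: a list of text and `image_url` items, where the URL may be a `data:image` base64 payload, a local file path, or a web address. A client-side sketch, assuming the bundled API server is running on the default port; the model name and image URL are placeholders:

```python
import os

from openai import OpenAI

client = OpenAI(
    api_key=os.environ.get("API_KEY", "0"),
    base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)

result = client.chat.completions.create(
    model="llava",  # placeholder model name
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
)
print(result.choices[0].message.content)
```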
def _create_stream_chat_completion_chunk( def _create_stream_chat_completion_chunk(
@ -97,11 +142,12 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse": ) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request) input_messages, system, tools, image = _process_request(request)
responses = await chat_model.achat( responses = await chat_model.achat(
input_messages, input_messages,
system, system,
tools, tools,
image,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@ -145,7 +191,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools = _process_request(request) input_messages, system, tools, image = _process_request(request)
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@ -159,6 +205,7 @@ async def create_stream_chat_completion_response(
input_messages, input_messages,
system, system,
tools, tools,
image,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any, Dict

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time import time
from enum import Enum, unique from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
@ -56,9 +70,19 @@ class FunctionCall(BaseModel):
function: Function function: Function
class ImageURL(BaseModel):
url: str
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url"]
text: Optional[str] = None
image_url: Optional[ImageURL] = None
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Role role: Role
content: Optional[str] = None content: Optional[Union[str, List[MultimodalInputItem]]] = None
tool_calls: Optional[List[FunctionCall]] = None tool_calls: Optional[List[FunctionCall]] = None
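Server-side, the same payload deserializes into the new models. A sketch of how they nest, using the field names defined above; the sample values are hypothetical:

```python
from llamafactory.api.protocol import ChatMessage, ImageURL, MultimodalInputItem, Role

message = ChatMessage(
    role=Role.USER,
    content=[
        MultimodalInputItem(type="text", text="Describe this image."),
        MultimodalInputItem(type="image_url", image_url=ImageURL(url="https://example.com/cat.png")),
    ],
)
```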

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_engine import BaseEngine from .base_engine import BaseEngine
from .chat_model import ChatModel from .chat_model import ChatModel

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union

View File

@ -1,3 +1,20 @@
# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio import asyncio
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import os import os
@ -8,6 +22,7 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@ -23,6 +38,9 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class HuggingfaceEngine(BaseEngine): class HuggingfaceEngine(BaseEngine):
def __init__( def __init__(
self, self,
@ -79,6 +97,7 @@ class HuggingfaceEngine(BaseEngine):
prompt_length = len(prompt_ids) prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device) inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
temperature: Optional[float] = input_kwargs.pop("temperature", None) temperature: Optional[float] = input_kwargs.pop("temperature", None)
@ -92,7 +111,7 @@ class HuggingfaceEngine(BaseEngine):
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None: if stop is not None:
raise ValueError("Stop parameter is not supported in Huggingface engine yet.") logger.warning("Stop parameter is not supported in Huggingface engine yet.")
generating_args = generating_args.copy() generating_args = generating_args.copy()
generating_args.update( generating_args.update(
@ -132,6 +151,7 @@ class HuggingfaceEngine(BaseEngine):
gen_kwargs = dict( gen_kwargs = dict(
inputs=inputs, inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args), generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(), logits_processor=get_logits_processor(),
) )

View File

@ -1,19 +1,37 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
from ..model import load_config, load_tokenizer from ..model import load_config, load_tokenizer
from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
if is_vllm_available(): if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import MultiModalData
if is_vllm_version_greater_than_0_5():
from vllm.multimodal.image import ImagePixelData
else:
from vllm.sequence import MultiModalData
if TYPE_CHECKING: if TYPE_CHECKING:
@ -35,8 +53,6 @@ class VllmEngine(BaseEngine):
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ) -> None:
config = load_config(model_args) # may download model from ms hub config = load_config(model_args) # may download model from ms hub
infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
infer_dtype = str(infer_dtype).split(".")[-1]
self.can_generate = finetuning_args.stage == "sft" self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
@ -50,7 +66,7 @@ class VllmEngine(BaseEngine):
"model": model_args.model_name_or_path, "model": model_args.model_name_or_path,
"trust_remote_code": True, "trust_remote_code": True,
"download_dir": model_args.cache_dir, "download_dir": model_args.cache_dir,
"dtype": infer_dtype, "dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen, "max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1, "tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util, "gpu_memory_utilization": model_args.vllm_gpu_util,
@ -70,7 +86,6 @@ class VllmEngine(BaseEngine):
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size) engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
# bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
import vllm.model_executor.models.llava import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.") logger.info("Detected Yi-VL model, applying projector patch.")
@ -109,7 +124,10 @@ class VllmEngine(BaseEngine):
if self.processor is not None and image is not None: # add image features if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"] pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values) if is_vllm_version_greater_than_0_5():
multi_modal_data = ImagePixelData(image=pixel_values)
else: # TODO: remove vllm 0.4.3 support
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else: else:
multi_modal_data = None multi_modal_data = None
@ -158,12 +176,10 @@ class VllmEngine(BaseEngine):
) )
result_generator = self.model.generate( result_generator = self.model.generate(
prompt=None, inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_ids,
lora_request=self.lora_request, lora_request=self.lora_request,
multi_modal_data=multi_modal_data,
) )
return result_generator return result_generator

View File

@ -1,9 +1,30 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import subprocess
import sys import sys
from enum import Enum, unique from enum import Enum, unique
from . import launcher
from .api.app import run_api from .api.app import run_api
from .chat.chat_model import run_chat from .chat.chat_model import run_chat
from .eval.evaluator import run_eval from .eval.evaluator import run_eval
from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui from .webui.interface import run_web_demo, run_web_ui
@ -23,8 +44,6 @@ USAGE = (
+ "-" * 70 + "-" * 70
) )
VERSION = "0.7.2.dev0"
WELCOME = ( WELCOME = (
"-" * 58 "-" * 58
+ "\n" + "\n"
@ -37,11 +56,14 @@ WELCOME = (
+ "-" * 58 + "-" * 58
) )
logger = get_logger(__name__)
@unique @unique
class Command(str, Enum): class Command(str, Enum):
API = "api" API = "api"
CHAT = "chat" CHAT = "chat"
ENV = "env"
EVAL = "eval" EVAL = "eval"
EXPORT = "export" EXPORT = "export"
TRAIN = "train" TRAIN = "train"
@ -57,12 +79,35 @@ def main():
run_api() run_api()
elif command == Command.CHAT: elif command == Command.CHAT:
run_chat() run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL: elif command == Command.EVAL:
run_eval() run_eval()
elif command == Command.EXPORT: elif command == Command.EXPORT:
export_model() export_model()
elif command == Command.TRAIN: elif command == Command.TRAIN:
run_exp() force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
),
shell=True,
)
else:
run_exp()
elif command == Command.WEBDEMO: elif command == Command.WEBDEMO:
run_web_demo() run_web_demo()
elif command == Command.WEBUI: elif command == Command.WEBUI:
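For reference, the TRAIN branch above shells out to `torchrun` whenever more than one device is visible or `FORCE_TORCHRUN` is set, reading the rendezvous settings from the environment. A rough Python-side equivalent of a single-node launch (the config path is a placeholder; the environment variable names are the ones read above):

```python
import os
import subprocess

env = dict(
    os.environ,
    FORCE_TORCHRUN="1",
    NNODES="1",
    RANK="0",
    NPROC_PER_NODE="8",
    MASTER_ADDR="127.0.0.1",
    MASTER_PORT="29500",
)
subprocess.run(["llamafactory-cli", "train", "train_config.yaml"], env=env, check=True)
```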

View File

@ -1,16 +1,30 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
from .utils import Role, split_dataset
__all__ = [ __all__ = [
"KTODataCollatorWithPadding", "KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding", "PairwiseDataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",
"templates",
"Role", "Role",
"split_dataset", "split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
] ]

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union from typing import TYPE_CHECKING, Any, Dict, List, Union
@ -5,11 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features from datasets import Features
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .utils import Role from .data_utils import Role
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments from ..hparams import DataArguments
from .parser import DatasetAttr from .parser import DatasetAttr
@ -175,7 +190,10 @@ def convert_sharegpt(
def align_dataset( def align_dataset(
dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments" dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""
Aligned dataset: Aligned dataset:
@ -208,7 +226,7 @@ def align_dataset(
if not data_args.streaming: if not data_args.streaming:
kwargs = dict( kwargs = dict(
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache), load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset", desc="Converting format of dataset",
) )

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum, unique from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Tuple, Union

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod

View File

@ -1,24 +1,38 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
import os import os
import sys import sys
from typing import TYPE_CHECKING, Literal, Optional, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
from ..extras.constants import FILEEXT2TYPE from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data from ..extras.misc import has_tokenized_data
from .aligner import align_dataset from .aligner import align_dataset
from .data_utils import merge_dataset
from .parser import get_dataset_list from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer from .template import get_template_and_fix_tokenizer
from .utils import merge_dataset
if TYPE_CHECKING: if TYPE_CHECKING:
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from transformers import ProcessorMixin, Seq2SeqTrainingArguments from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr from .parser import DatasetAttr
@ -31,6 +45,7 @@ def load_single_dataset(
dataset_attr: "DatasetAttr", dataset_attr: "DatasetAttr",
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
logger.info("Loading dataset {}...".format(dataset_attr)) logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
@ -61,9 +76,9 @@ def load_single_dataset(
raise ValueError("File {} not found.".format(local_path)) raise ValueError("File {} not found.".format(local_path))
if data_path is None: if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.") raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
else: else:
raise NotImplementedError raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
if dataset_attr.load_from == "ms_hub": if dataset_attr.load_from == "ms_hub":
try: try:
@ -106,18 +121,30 @@ def load_single_dataset(
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if data_args.max_samples is not None: # truncate dataset if dataset_attr.num_samples is not None and not data_args.streaming:
num_samples = min(data_args.max_samples, len(dataset)) target_num = dataset_attr.num_samples
dataset = dataset.select(range(num_samples)) indexes = np.random.permutation(len(dataset))[:target_num]
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
indexes = np.concatenate((indexes, expand_indexes), axis=0)
return align_dataset(dataset, dataset_attr, data_args) assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(max_samples))
return align_dataset(dataset, dataset_attr, data_args, training_args)
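`num_samples` is a new per-dataset key read from `dataset_info.json` (see the parser change below); when set, the loader draws that many examples, topping up with replacement if the dataset has fewer rows. A hypothetical entry with placeholder names:

```python
import json

dataset_info = {
    "my_dataset": {
        "file_name": "my_dataset.json",  # local file under the data directory
        "formatting": "alpaca",
        "num_samples": 1000,             # consumed by load_single_dataset above
    }
}

with open("data/dataset_info.json", "w", encoding="utf-8") as f:
    json.dump(dataset_info, f, indent=2)
```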
def get_dataset( def get_dataset(
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
@ -144,7 +171,8 @@ def get_dataset(
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.") raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
dataset = merge_dataset(all_datasets, data_args, training_args) dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"): with training_args.main_process_first(desc="pre-process dataset"):
@ -156,7 +184,7 @@ def get_dataset(
if not data_args.streaming: if not data_args.streaming:
kwargs = dict( kwargs = dict(
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache), load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Running tokenizer on dataset", desc="Running tokenizer on dataset",
) )
@ -166,7 +194,7 @@ def get_dataset(
if training_args.should_save: if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path) dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path)) logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
sys.exit(0) sys.exit(0)

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
@ -20,11 +34,12 @@ class DatasetAttr:
""" basic configs """ """ basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
""" extra configs """ """ extra configs """
subset: Optional[str] = None subset: Optional[str] = None
folder: Optional[str] = None folder: Optional[str] = None
ranking: bool = False num_samples: Optional[int] = None
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" common columns """ """ common columns """
system: Optional[str] = None system: Optional[str] = None
tools: Optional[str] = None tools: Optional[str] = None
@ -102,10 +117,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
else: else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("num_samples", dataset_info[name])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import ProcessorMixin, Seq2SeqTrainingArguments from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments from ..hparams import DataArguments
from .template import Template from .template import Template
@ -23,7 +36,7 @@ if TYPE_CHECKING:
def get_preprocess_and_print_func( def get_preprocess_and_print_func(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],

View File

@ -1,13 +1,26 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments from ...hparams import DataArguments
from ..template import Template from ..template import Template
@ -16,6 +29,55 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def _encode_feedback_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, kl_response_ids = template.encode_oneturn(
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
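A note on the branching above: the KTO data convention is that `response[0]` carries the answer for desired examples and `response[1]` for undesired ones, and the KL response is selected the same way. A toy sketch of just that tagging logic, with made-up data:

```python
# Toy illustration of how kto_tag and the chosen message are derived (data is made up).
prompt = [{"role": "user", "content": "What is 2 + 2?"}]
desired = [{"role": "assistant", "content": "4"}, {"role": "assistant", "content": ""}]
undesired = [{"role": "assistant", "content": ""}, {"role": "assistant", "content": "5"}]

for response in (desired, undesired):
    kto_tag = bool(response[0]["content"])
    messages = prompt + [response[0] if kto_tag else response[1]]
    print(kto_tag, messages[-1]["content"])  # True 4, then False 5
```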
def preprocess_feedback_dataset( def preprocess_feedback_dataset(
examples: Dict[str, List[Any]], examples: Dict[str, List[Any]],
template: "Template", template: "Template",
@ -45,50 +107,17 @@ def preprocess_feedback_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] prompt=examples["prompt"][i],
response=examples["response"][i],
if examples["response"][i][0]["content"]: # desired example kl_response=kl_response[i],
kto_tag = True system=examples["system"][i],
messages = examples["prompt"][i] + [examples["response"][i][0]] tools=examples["tools"][i],
else: # undesired example template=template,
kto_tag = False tokenizer=tokenizer,
messages = examples["prompt"][i] + [examples["response"][i][1]] processor=processor,
data_args=data_args,
if kl_response[i][0]["content"]:
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
else:
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
) )
_, kl_response_ids = template.encode_oneturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)

View File

@ -1,27 +0,0 @@
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
# process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
# get paligemma token type ids for computing loss
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)

View File

@ -1,13 +1,26 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments from ...hparams import DataArguments
from ..template import Template from ..template import Template
@ -16,6 +29,44 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def _encode_pairwise_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
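The end result is that the chosen and rejected sequences share the same loss-masked prompt prefix and differ only in their answer tokens. A minimal sketch with dummy token ids (using -100, the usual ignore index):

```python
# Dummy-id sketch of the pairwise outputs produced above.
IGNORE_INDEX = -100
prompt_ids, chosen_ids, rejected_ids = [1, 2, 3], [10, 11], [20, 21, 22]

chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids      # [-100, -100, -100, 10, 11]
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids  # [-100, -100, -100, 20, 21, 22]
```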
def preprocess_pairwise_dataset( def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]], examples: Dict[str, List[Any]],
template: "Template", template: "Template",
@ -43,40 +94,16 @@ def preprocess_pairwise_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] prompt=examples["prompt"][i],
response=examples["response"][i],
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]] system=examples["system"][i],
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]] tools=examples["tools"][i],
prompt_ids, chosen_ids = template.encode_oneturn( template=template,
tokenizer, tokenizer=tokenizer,
chosen_messages, processor=processor,
examples["system"][i], data_args=data_args,
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
) )
_, rejected_ids = template.encode_oneturn(
tokenizer,
rejected_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
model_inputs["chosen_input_ids"].append(chosen_input_ids) model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids)) model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels) model_inputs["chosen_labels"].append(chosen_labels)

View File

@ -1,9 +1,26 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers import PreTrainedTokenizer
from ...hparams import DataArguments from ...hparams import DataArguments
@ -12,13 +29,14 @@ def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
if not data_args.packing: if not data_args.packing:
if data_args.template == "gemma": if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples] text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
else: else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False) tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
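For readers unfamiliar with the packed branch: all tokenized documents are flattened with `itertools.chain` before being regrouped (the rest of the packing logic is outside this hunk). A tiny self-contained illustration of that flattening step:

```python
# Minimal illustration of the chain-based flattening used when packing is enabled.
from itertools import chain

tokenized_examples = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
concatenated = {k: list(chain(*v)) for k, v in tokenized_examples.items()}
print(concatenated["input_ids"])  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```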

View File

@ -0,0 +1,78 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
while numbers:
current_knapsack = []
remaining_capacity = capacity
while True:
index = search_for_fit(numbers, remaining_capacity)
if index == -1:
break # no more numbers fit in this knapsack
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
knapsacks.append(current_knapsack)
return knapsacks
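A quick, illustrative run of the two helpers above; note that `greedy_knapsack` sorts and consumes its input list in place, and the import path below follows this repo's `src/llamafactory` layout:

```python
# Each returned knapsack is a list of lengths whose sum stays within the capacity (cutoff_len).
from llamafactory.data.processors.processor_utils import greedy_knapsack

lengths = [2, 3, 5, 6, 7]
print(greedy_knapsack(lengths, capacity=8))  # expected: [[7], [6, 2], [5, 3]]
```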
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. (currently only supports a single image)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
r"""
Gets paligemma token type ids for computing loss.
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
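The token type ids simply mark image positions as 0 and text positions as 1, for example:

```python
# With an image prefix of length 4 and a total input of 7 tokens:
image_seq_length, input_len = 4, 7
print([0] * image_seq_length + [1] * (input_len - image_seq_length))  # [0, 0, 0, 0, 1, 1, 1]
```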

View File

@ -1,13 +1,27 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments from ...hparams import DataArguments
from ..template import Template from ..template import Template
@ -16,6 +30,48 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def _encode_supervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
messages = prompt + response
input_ids, labels = [], []
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
encoded_pairs = template.encode_multiturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
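To make the masking concrete, here is a toy trace of the default branch (prompt tokens ignored in the loss, target tokens kept), using made-up token ids:

```python
# Toy trace of the per-turn masking above (default branch; IGNORE_INDEX = -100).
IGNORE_INDEX = -100
encoded_pairs = [([1, 2], [3, 4]), ([5], [6, 7])]  # (source_ids, target_ids) per turn

input_ids, labels = [], []
for source_ids, target_ids in encoded_pairs:
    input_ids += source_ids + target_ids
    labels += [IGNORE_INDEX] * len(source_ids) + target_ids

print(input_ids)  # [1, 2, 3, 4, 5, 6, 7]
print(labels)     # [-100, -100, 3, 4, -100, 6, 7]
```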
def preprocess_supervised_dataset( def preprocess_supervised_dataset(
examples: Dict[str, List[Any]], examples: Dict[str, List[Any]],
template: "Template", template: "Template",
@ -36,41 +92,16 @@ def preprocess_supervised_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models input_ids, labels = _encode_supervised_example(
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] prompt=examples["prompt"][i],
response=examples["response"][i],
messages = examples["prompt"][i] + examples["response"][i] system=examples["system"][i],
input_ids, labels = [], [] tools=examples["tools"][i],
template=template,
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models tokenizer=tokenizer,
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) processor=processor,
input_ids += [image_token_id] * getattr(processor, "image_seq_length") data_args=data_args,
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") )
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
@ -90,41 +121,55 @@ def preprocess_packed_supervised_dataset(
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>` # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} valid_num = 0
input_ids, labels = [], [] batch_input_ids, batch_labels = [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["prompt"])): for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
messages = examples["prompt"][i] + examples["response"][i] input_ids, labels = _encode_supervised_example(
for source_ids, target_ids in template.encode_multiturn( prompt=examples["prompt"][i],
tokenizer, messages, examples["system"][i], examples["tools"][i] response=examples["response"][i],
): system=examples["system"][i],
if data_args.train_on_prompt: tools=examples["tools"][i],
source_mask = source_ids template=template,
elif len(input_ids) != 0 and template.efficient_eos: tokenizer=tokenizer,
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) processor=None,
else: data_args=data_args,
source_mask = [IGNORE_INDEX] * len(source_ids) )
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
valid_num += 1
input_ids += source_ids + target_ids model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
labels += source_mask + target_ids knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_labels = [], []
for length in knapsack:
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
if template.efficient_eos: if len(packed_input_ids) < data_args.cutoff_len:
input_ids += [tokenizer.eos_token_id] pad_length = data_args.cutoff_len - len(packed_input_ids)
labels += [tokenizer.eos_token_id] packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
total_length = len(input_ids) if len(packed_input_ids) != data_args.cutoff_len:
block_size = data_args.cutoff_len raise ValueError("The length of packed example should be identical to the cutoff length.")
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size model_inputs["input_ids"].append(packed_input_ids)
# split by chunks of cutoff_len model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
for i in range(0, total_length, block_size): model_inputs["labels"].append(packed_labels)
if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
model_inputs["input_ids"].append(input_ids[i : i + block_size])
model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i : i + block_size])
return model_inputs return model_inputs
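In short, the rewrite replaces fixed-size chunking with length-bucketed greedy packing: each encoded example keeps its own prompt mask, over-long examples are dropped with a warning, and every pack is padded up to `cutoff_len` (note the attention mask covers the whole pack, padding included). A toy sketch of the final padding step:

```python
# Toy sketch of the padding applied to each knapsack (IGNORE_INDEX = -100).
IGNORE_INDEX, pad_token_id, cutoff_len = -100, 0, 8
packed_input_ids = [1, 2, 3, 4, 5, 6]
packed_labels = [IGNORE_INDEX, 2, 3, IGNORE_INDEX, 5, 6]

pad_length = cutoff_len - len(packed_input_ids)
packed_input_ids += [pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
assert len(packed_input_ids) == len(packed_labels) == cutoff_len
```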

View File

@ -1,13 +1,26 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ..utils import Role
from ..data_utils import Role
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from ...hparams import DataArguments from ...hparams import DataArguments
from ..template import Template from ..template import Template
@ -16,6 +29,37 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def _encode_unsupervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
return input_ids, labels
def preprocess_unsupervised_dataset( def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]], examples: Dict[str, List[Any]],
template: "Template", template: "Template",
@ -35,30 +79,16 @@ def preprocess_unsupervised_dataset(
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models input_ids, labels = _encode_unsupervised_example(
examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] prompt=examples["prompt"][i],
response=examples["response"][i],
if len(examples["response"][i]) == 1: system=examples["system"][i],
messages = examples["prompt"][i] + examples["response"][i] tools=examples["tools"][i],
else: template=template,
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}] tokenizer=tokenizer,
processor=processor,
input_ids, labels = template.encode_oneturn( data_args=data_args,
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
) )
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)

View File

@ -1,9 +1,23 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .data_utils import Role, infer_max_len
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .utils import Role, infer_max_len
if TYPE_CHECKING: if TYPE_CHECKING:
@ -196,7 +210,7 @@ class Llama2Template(Template):
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
templates: Dict[str, Template] = {}
TEMPLATES: Dict[str, Template] = {}
def _register_template( def _register_template(
@ -248,7 +262,7 @@ def _register_template(
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter() default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter, format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter, format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter, format_system=format_system or default_user_formatter,
@ -348,9 +362,9 @@ def get_template_and_fix_tokenizer(
name: Optional[str] = None, name: Optional[str] = None,
) -> Template: ) -> Template:
if name is None: if name is None:
template = templates["empty"] # placeholder template = TEMPLATES["empty"] # placeholder
else: else:
template = templates.get(name, None)
template = TEMPLATES.get(name, None)
if template is None: if template is None:
raise ValueError("Template {} does not exist.".format(name)) raise ValueError("Template {} does not exist.".format(name))
@ -544,8 +558,13 @@ _register_template(
) )
] ]
), ),
format_system=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
format_system=StringFormatter(
slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
),
default_system=(
"You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
"by providing thorough responses. You are trained by Cohere."
),
) )
@ -653,6 +672,19 @@ _register_template(
) )
_register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
)
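Piecing the slot strings together, a single-turn glm4 conversation renders roughly as below. This is only a sketch of the string shape; the real `Template.encode_*` path also handles tools, stop words, and the `efficient_eos` behaviour.

```python
# Rough, illustrative rendering of one glm4 turn from the formatters registered above.
system_part = "[gMASK]<sop>" + ""  # format_system with an empty system prompt
user_part = "<|user|>\n" + "Hi there" + "<|assistant|>"
assistant_part = "\n" + "Hello!"
print(system_part + user_part + assistant_part)
# [gMASK]<sop><|user|>
# Hi there<|assistant|>
# Hello!
```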
_register_template( _register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@ -682,17 +714,8 @@ _register_template(
_register_template( _register_template(
name="llama2", name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]), format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
) )
@ -742,7 +765,6 @@ _register_template(
_register_template( _register_template(
name="olmo", name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]), format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
force_system=True, force_system=True,
) )
@ -751,12 +773,28 @@ _register_template(
_register_template( _register_template(
name="openchat", name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True, force_system=True,
) )
_register_template(
name="openchat-3.6",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
stop_words=["<|eot_id|>"],
replace_eos=True,
force_system=True,
)
_register_template( _register_template(
name="orion", name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@ -807,6 +845,15 @@ _register_template(
) )
_register_template(
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
stop_words=["<_end>"],
replace_eos=True,
)
_register_template( _register_template(
name="vicuna", name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@ -857,6 +904,7 @@ _register_template(
_register_template( _register_template(
name="yi", name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,

View File

@ -1,4 +1,41 @@
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by the Dan's test library.
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2020 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import inspect import inspect
import json import json
@ -26,9 +63,7 @@ class Evaluator:
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args) self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang) self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
]
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode() @torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
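This change pairs with the eval-template change further below: the answer suffix no longer ends in a trailing space, so the candidate letters are encoded on their own instead of with the old `prefix`. A sketch of how the choice token ids are gathered (the tokenizer name here is only an example):

```python
# Sketch of collecting one vocabulary id per answer letter; the last id of each encoding is kept.
from transformers import AutoTokenizer

CHOICES = ["A", "B", "C", "D"]
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the illustration
choice_inputs = [tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
print(choice_inputs)  # four ids, one per answer letter
```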

View File

@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple from typing import Dict, List, Sequence, Tuple
@ -10,7 +24,6 @@ class EvalTemplate:
system: str system: str
choice: str choice: str
answer: str answer: str
prefix: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r""" r"""
@ -42,8 +55,8 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {} eval_templates: Dict[str, "EvalTemplate"] = {}
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
def get_eval_template(name: str) -> "EvalTemplate": def get_eval_template(name: str) -> "EvalTemplate":
@ -56,8 +69,7 @@ _register_eval_template(
name="en", name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n", system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\nAnswer: ", answer="\nAnswer:",
prefix=" ",
) )
@ -66,5 +78,4 @@ _register_eval_template(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\n答案:", answer="\n答案:",
prefix=" ",
) )

Some files were not shown because too many files have changed in this diff.