Merge branch 'main' into main

Former-commit-id: 5f14910910154ba569435e7e68acbd6c30f79e80
hoshi-hiyouga committed on 2024-11-02 21:20:27 +08:00 (committed by GitHub)
commit d99e164cad
147 changed files with 3087 additions and 1833 deletions


@@ -7,6 +7,8 @@ data
docker
saves
hf_cache
ms_cache
om_cache
output
.dockerignore
.gitattributes


@@ -1,32 +1,35 @@
# Note: actually we do not support .env, just for reference
# api
-API_HOST=0.0.0.0
+API_HOST=
-API_PORT=8000
+API_PORT=
API_KEY=
-API_MODEL_NAME=gpt-3.5-turbo
+API_MODEL_NAME=
FASTAPI_ROOT_PATH=
MAX_CONCURRENT=
# general
DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=
MASTER_ADDR=
MASTER_PORT=
NNODES=
-RANK=
+NODE_RANK=
NPROC_PER_NODE=
# wandb
WANDB_DISABLED=
-WANDB_PROJECT=huggingface
+WANDB_PROJECT=
WANDB_API_KEY=
# gradio ui
-GRADIO_SHARE=False
+GRADIO_SHARE=
-GRADIO_SERVER_NAME=0.0.0.0
+GRADIO_SERVER_NAME=
GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
GRADIO_IPV6=
# setup
ENABLE_SHORT_CONSOLE=1
# reserved (do not use)
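As a quick orientation (not part of this commit), a hedged sketch of how the renamed torchrun variables above are typically supplied when launching a two-node run; the config path and addresses are placeholders mirroring the examples later in this changeset:

```bash
# hypothetical two-node launch; run with NODE_RANK=1 on the second machine
export FORCE_TORCHRUN=1
export NNODES=2
export NODE_RANK=0
export MASTER_ADDR=192.168.0.1   # placeholder address of the rank-0 node
export MASTER_PORT=29500
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```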


@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
### Style guide

LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
### Create a Pull Request
1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
git clone git@github.com:[username]/LLaMA-Factory.git
cd LLaMA-Factory
git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
```
3. Create a new branch to hold your development changes:
```bash
git checkout -b dev_your_branch
```
4. Set up a development environment by running the following command in a virtual environment:
```bash
pip install -e ".[dev]"
```
If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the `-e` flag.
5. Check code before commit:
```bash
make commit
make style && make quality
make test
```
6. Submit changes:
```bash
git add .
git commit -m "commit message"
git fetch upstream
git rebase upstream/main
git push -u origin dev_your_branch
```
7. Create a pull request from your branch `dev_your_branch` against the [original repository](https://github.com/hiyouga/LLaMA-Factory).


@@ -22,7 +22,7 @@ jobs:
      fail-fast: false
      matrix:
        python-version:
          - "3.8"  # TODO: remove py38 in next transformers release
          - "3.9"
          - "3.10"
          - "3.11"
@@ -54,7 +54,6 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-         python -m pip install git+https://github.com/huggingface/transformers.git
          python -m pip install ".[torch,dev]"
      - name: Check quality

.gitignore

@@ -162,6 +162,7 @@ cython_debug/
# custom .gitignore
ms_cache/
hf_cache/
om_cache/
cache/
config/
saves/

.pre-commit-config.yaml (new file)

@@ -0,0 +1,28 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-ast
- id: check-added-large-files
args: ['--maxkb=25000']
- id: check-merge-conflict
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/asottile/pyupgrade
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py38-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
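For reference, these hooks are exercised with the standard `pre-commit` workflow; a minimal sketch, assuming `pre-commit` is installed into the same environment:

```bash
pip install pre-commit
pre-commit install           # register the git hook configured by .pre-commit-config.yaml
pre-commit run --all-files   # lint and format the whole tree once, e.g. before opening a PR
```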


@@ -1,7 +1,14 @@
-.PHONY: quality style test
+.PHONY: build commit quality style test

check_dirs := scripts src tests setup.py

build:
    pip install build && python -m build

commit:
    pre-commit install
    pre-commit run --all-files

quality:
    ruff check $(check_dirs)
    ruff format --check $(check_dirs)

@@ -11,4 +18,4 @@ style:
    ruff format $(check_dirs)

test:
-   CUDA_VISIBLE_DEVICES= pytest tests/
+   CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
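A hedged sketch of how a contributor might drive the new targets; the `dev` extra comes from the CONTRIBUTING guide above, everything else is taken directly from this Makefile:

```bash
pip install -e ".[dev]"      # development dependencies, per CONTRIBUTING.md
make commit                  # install and run all pre-commit hooks
make style && make quality   # format, then lint with ruff
make test                    # pytest on CPU with W&B disabled
make build                   # build a distributable package
```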

README.md

@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-91-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-92-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -26,10 +26,17 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3

Choose your path:

- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Local machine**: Please refer to [usage](#getting-started)
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/

Recent activities:

- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)

> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please use them with caution.

## Table of Contents

- [Features](#features)
@@ -72,6 +79,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.

[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.

@@ -130,7 +139,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).

-[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.

[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
@@ -162,36 +171,39 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Supported Models

| Model                                                             | Model size                       | Template         |
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc)                 | 7B/13B                           | baichuan2        |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience)                 | 560M/1.1B/1.7B/3B/7.1B/176B      | -                |
| [ChatGLM3](https://huggingface.co/THUDM)                          | 6B                               | chatglm3         |
| [Command R](https://huggingface.co/CohereForAI)                   | 35B/104B                         | cohere           |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)         | 7B/16B/67B/236B                  | deepseek         |
| [Falcon](https://huggingface.co/tiiuae)                           | 7B/11B/40B/180B                  | falcon           |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)          | 2B/7B/9B/27B                     | gemma            |
| [GLM-4](https://huggingface.co/THUDM)                             | 9B                               | glm4             |
| [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index            |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm)          | 7B/20B                           | intern2          |
| [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                |
| [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2           |
| [Llama 3-3.2](https://huggingface.co/meta-llama)                  | 1B/3B/8B/70B                     | llama3           |
| [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava            |
| [LLaVA-NeXT](https://huggingface.co/llava-hf)                     | 7B/8B/13B/34B/72B/110B           | llava_next       |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf)               | 7B/34B                           | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B/4B                         | cpm/cpm3         |
| [Mistral/Mixtral](https://huggingface.co/mistralai)               | 7B/8x7B/8x22B                    | mistral          |
| [OLMo](https://huggingface.co/allenai)                            | 1B/7B                            | -                |
| [PaliGemma](https://huggingface.co/google)                        | 3B                               | paligemma        |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft)                 | 1.3B/2.7B                        | -                |
| [Phi-3](https://huggingface.co/microsoft)                         | 4B/14B                           | phi              |
| [Phi-3-small](https://huggingface.co/microsoft)                   | 7B                               | phi_small        |
| [Pixtral](https://huggingface.co/mistralai)                       | 12B                              | pixtral          |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)       | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen             |
| [Qwen2-VL](https://huggingface.co/Qwen)                           | 2B/7B/72B                        | qwen2_vl         |
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                |
| [XVERSE](https://huggingface.co/xverse)                           | 7B/13B/65B                       | xverse           |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai)                  | 1.5B/6B/9B/34B                   | yi               |
| [Yi-VL](https://huggingface.co/01-ai)                             | 6B/34B                           | yi_vl            |
| [Yuan 2](https://huggingface.co/IEITYuan)                         | 2B/51B/102B                      | yuan             |

> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -360,7 +372,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -412,7 +424,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij

### Data Preparation

Please refer to [data/README.md](data/README.md) for details about the dataset file format. You can either use datasets on the HuggingFace / ModelScope / Modelers hub or load a dataset from local disk.

> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset.
@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
docker run -dit --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v /usr/local/dcmi:/usr/local/dcmi \
@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v ./saves:/app/saves \
@@ -557,6 +572,7 @@ docker exec -it llamafactory bash

- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
- `om_cache`: Similar to Hugging Face cache but for Modelers users.
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml

> [!TIP]
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
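For a quick smoke test of the OpenAI-style API mentioned above, a minimal sketch with `curl`; the `/v1/chat/completions` path and the model name follow the usual OpenAI conventions and are assumptions here rather than something this commit documents:

```bash
# assumes the server was started with: API_PORT=8000 llamafactory-cli api <config>.yaml
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "test", "messages": [{"role": "user", "content": "Hello!"}]}'
```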
### Download from ModelScope Hub

@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows

Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
### Download from Modelers Hub
You can also use Modelers Hub to download models and datasets.
```bash
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
```
Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
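To make the new hub option concrete, a minimal sketch; the model ID `TeleAI/TeleChat-7B-pt` is taken from the paragraph above, while the config file and the idea of reusing an existing LoRA SFT recipe are assumptions:

```bash
export USE_OPENMIND_HUB=1   # `set USE_OPENMIND_HUB=1` for Windows
# in your config, e.g. examples/train_lora/llama3_lora_sft.yaml, set:
#   model_name_or_path: TeleAI/TeleChat-7B-pt
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```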
### Use W&B Logger

To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
@@ -684,11 +712,13 @@ If you have a project that should be incorporated, please contact via email or c

1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generating metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)

</details>
@@ -696,7 +726,7 @@ If you have a project that should be incorporated, please contact via email or c

This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation


@@ -4,7 +4,7 @@
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-91-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-92-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -26,11 +26,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

Choose your path:

- **Colab**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Local machine**: see [How to use](#如何使用)
- **Beginner tutorial**: https://zhuanlan.zhihu.com/p/695287607
- **Documentation**: https://llamafactory.readthedocs.io/zh-cn/latest/

Recent activities:

- **2024/10/18-2024/11/30**: Build a personalized tour-guide bot with PAI+LLaMA Factory. [[event page]](https://developer.aliyun.com/topic/llamafactory2)

> [!NOTE]
> Except for the links above, all other websites are unauthorized third-party websites. Please use them with caution.

## Table of Contents

- [Features](#项目特色)
@@ -73,6 +80,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## Changelog

[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#从魔乐社区下载) for usage.

[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thanks to [@simonJJJ](https://github.com/simonJJJ)'s PR.
@@ -163,35 +172,38 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## Models

| Model                                                             | Model size                       | Template         |
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc)                 | 7B/13B                           | baichuan2        |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience)                 | 560M/1.1B/1.7B/3B/7.1B/176B      | -                |
| [ChatGLM3](https://huggingface.co/THUDM)                          | 6B                               | chatglm3         |
| [Command R](https://huggingface.co/CohereForAI)                   | 35B/104B                         | cohere           |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)         | 7B/16B/67B/236B                  | deepseek         |
| [Falcon](https://huggingface.co/tiiuae)                           | 7B/11B/40B/180B                  | falcon           |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)          | 2B/7B/9B/27B                     | gemma            |
| [GLM-4](https://huggingface.co/THUDM)                             | 9B                               | glm4             |
| [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index            |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm)          | 7B/20B                           | intern2          |
| [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                |
| [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2           |
| [Llama 3-3.2](https://huggingface.co/meta-llama)                  | 1B/3B/8B/70B                     | llama3           |
| [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava            |
| [LLaVA-NeXT](https://huggingface.co/llava-hf)                     | 7B/8B/13B/34B/72B/110B           | llava_next       |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf)               | 7B/34B                           | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B/4B                         | cpm/cpm3         |
| [Mistral/Mixtral](https://huggingface.co/mistralai)               | 7B/8x7B/8x22B                    | mistral          |
| [OLMo](https://huggingface.co/allenai)                            | 1B/7B                            | -                |
| [PaliGemma](https://huggingface.co/google)                        | 3B                               | paligemma        |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft)                 | 1.3B/2.7B                        | -                |
| [Phi-3](https://huggingface.co/microsoft)                         | 4B/7B/14B                        | phi              |
| [Pixtral](https://huggingface.co/mistralai)                       | 12B                              | pixtral          |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)       | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen             |
| [Qwen2-VL](https://huggingface.co/Qwen)                           | 2B/7B/72B                        | qwen2_vl         |
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                |
| [XVERSE](https://huggingface.co/xverse)                           | 7B/13B/65B                       | xverse           |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai)                  | 1.5B/6B/9B/34B                   | yi               |
| [Yi-VL](https://huggingface.co/01-ai)                             | 6B/34B                           | yi_vl            |
| [Yuan 2](https://huggingface.co/IEITYuan)                         | 2B/51B/102B                      | yuan             |

> [!NOTE]
> For all "base" models, the `template` argument can be `default`, `alpaca`, `vicuna`, or any other value. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -360,7 +372,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
@@ -412,7 +424,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh

### Data Preparation

For the dataset file format, please refer to [data/README_zh.md](data/README_zh.md). You can use datasets on HuggingFace / ModelScope / Modelers or load a local dataset.

> [!NOTE]
> Please update `data/dataset_info.json` when using a custom dataset.
@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
docker run -dit --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v /usr/local/dcmi:/usr/local/dcmi \
@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v ./saves:/app/saves \
@@ -557,6 +572,7 @@ docker exec -it llamafactory bash

- `hf_cache`: Use the Hugging Face cache directory on the host machine; it can be changed to a different directory.
- `ms_cache`: Similar to the Hugging Face cache, but for ModelScope users.
- `om_cache`: Similar to the Hugging Face cache, but for Modelers users.
- `data`: Directory on the host machine that stores the datasets.
- `output`: Set the export directory to this path so that the merged model can be accessed from the host machine.
@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml

> [!TIP]
> See [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
### Download from ModelScope Hub

@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows

Set `model_name_or_path` to a model ID to load the corresponding model. You can find all available models at the [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
### Download from Modelers Hub

You can also download models and datasets from the Modelers Hub as follows.

```bash
export USE_OPENMIND_HUB=1 # use `set USE_OPENMIND_HUB=1` on Windows
```

Set `model_name_or_path` to a model ID to load the corresponding model. You can find all available models at the [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
### Use W&B Logger

To log experiment results with [Weights & Biases](https://wandb.ai), add the following arguments to the yaml file.
@@ -684,11 +712,12 @@ run_name: test_run # optional

1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A Chinese medical large language model, fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for the Chinese medical domain, fine-tuned from LLaMA2-7B and Baichuan-13B on Chinese medical data.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models, capable of giving any LLM one of 16 personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model for generating Stable Diffusion prompts. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A Chinese multimodal medical large language model, fine-tuned from LLaVA-1.5-7B on Chinese multimodal medical data.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning large language models on Windows PCs with NVIDIA RTX devices.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: A low-code tool for building multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full-pipeline codebase for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)

</details>
@@ -696,7 +725,7 @@ run_name: test_run # optional

The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

Binary image files not shown (two assets updated: 145 KiB → 132 KiB and 149 KiB → 132 KiB).


@@ -17,9 +17,9 @@ _CITATION = """\
}
"""

-_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
_LICENSE = "gpl-3.0"
-_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"


class BelleMultiturn(datasets.GeneratorBasedBuilder):

@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str):
-       with open(filepath, "r", encoding="utf-8") as f:
+       with open(filepath, encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                conversations = []


@@ -54,7 +54,8 @@
  },
  "alpaca_en": {
    "hf_hub_url": "llamafactory/alpaca_en",
    "ms_hub_url": "llamafactory/alpaca_en",
    "om_hub_url": "HaM/alpaca_en"
  },
  "alpaca_zh": {
    "hf_hub_url": "llamafactory/alpaca_zh",
@@ -66,7 +67,8 @@
  },
  "alpaca_gpt4_zh": {
    "hf_hub_url": "llamafactory/alpaca_gpt4_zh",
    "ms_hub_url": "llamafactory/alpaca_gpt4_zh",
    "om_hub_url": "State_Cloud/alpaca-gpt4-data-zh"
  },
  "glaive_toolcall_en": {
    "hf_hub_url": "llamafactory/glaive_toolcall_en",


@@ -8,9 +8,9 @@ import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
-_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
-_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",

@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
    def _generate_examples(self, filepaths: List[str]):
        key = 0
        for filepath in filepaths:
-           with open(filepath, "r", encoding="utf-8") as f:
+           with open(filepath, encoding="utf-8") as f:
                for row in f:
                    data = json.loads(row)
                    chosen = data["chosen"]


@@ -20,9 +20,9 @@ _CITATION = """\
}
"""

-_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
-_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
+_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"


class UltraChat(datasets.GeneratorBasedBuilder):

@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
    def _generate_examples(self, filepaths: List[str]):
        for filepath in filepaths:
-           with open(filepath, "r", encoding="utf-8") as f:
+           with open(filepath, encoding="utf-8") as f:
                for row in f:
                    try:
                        data = json.loads(row)


@@ -1,6 +1,7 @@
-# Use the NVIDIA official image with PyTorch 2.3.0
-# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
-FROM nvcr.io/nvidia/pytorch:24.02-py3
+# Default use the NVIDIA official image with PyTorch 2.3.0
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
+ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
+FROM ${BASE_IMAGE}

# Define environments
ENV MAX_JOBS=4

@@ -12,6 +13,9 @@ ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG INSTALL_EETQ=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory

@@ -38,6 +42,15 @@ RUN EXTRA_PACKAGES="metrics"; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
    fi; \
    if [ "$INSTALL_HQQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
    fi; \
    if [ "$INSTALL_EETQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention
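To show how the new build arguments are meant to be consumed, a hedged sketch of a build command; the Dockerfile path follows the README's `docker build` example, and enabling every optional extra at once plus the image tag are only illustrative assumptions:

```bash
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3 \
    --build-arg INSTALL_LIGER_KERNEL=true \
    --build-arg INSTALL_HQQ=true \
    --build-arg INSTALL_EETQ=true \
    -t llamafactory:latest .
```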


@@ -8,11 +8,15 @@ services:
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        INSTALL_FLASHATTN: false
        INSTALL_LIGER_KERNEL: false
        INSTALL_HQQ: false
        INSTALL_EETQ: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
    ports:


@@ -10,6 +10,7 @@ services:
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
      - /usr/local/dcmi:/usr/local/dcmi


@@ -10,6 +10,8 @@ ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory

@@ -36,6 +38,12 @@ RUN EXTRA_PACKAGES="metrics"; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
    fi; \
    if [ "$INSTALL_HQQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention


@@ -8,11 +8,14 @@ services:
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        INSTALL_FLASHATTN: false
        INSTALL_LIGER_KERNEL: false
        INSTALL_HQQ: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
      - ../../saves:/app/saves


@@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
        df = pd.read_csv(filepath, header=None)
        df.columns = ["question", "A", "B", "C", "D", "answer"]
-       for i, instance in enumerate(df.to_dict(orient="records")):
-           yield i, instance
+       yield from enumerate(df.to_dict(orient="records"))


@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml

#### Supervised Fine-Tuning on Multiple Nodes

```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)


@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml

#### Supervised Fine-Tuning on Multiple Nodes

```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```

#### Distribute VRAM Evenly with DeepSpeed ZeRO-3


@@ -1,9 +1,9 @@
-transformers>=4.41.2,<=4.45.0
+transformers>=4.41.2,<=4.46.1
-datasets>=2.16.0,<=2.21.0
+datasets>=2.16.0,<=3.0.2
-accelerate>=0.30.1,<=0.34.2
+accelerate>=0.34.0,<=1.0.1
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
-gradio>=4.0.0
+gradio>=4.0.0,<5.0.0
pandas>=2.0.0
scipy
einops

@@ -19,3 +19,4 @@ fire
packaging
pyyaml
numpy<2.0.0
av


@@ -1,4 +1,3 @@
-# coding=utf-8
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
#
# This code is inspired by the Microsoft's DeepSpeed library.


@@ -1,4 +1,3 @@
-# coding=utf-8
# Copyright 2024 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.

@@ -74,7 +73,7 @@ def calculate_lr(
    elif stage == "sft":
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
    else:
-       raise NotImplementedError("Stage does not supported: {}.".format(stage))
+       raise NotImplementedError(f"Stage does not supported: {stage}.")

    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)

    valid_tokens, total_tokens = 0, 0


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
elif "4090" in device_name: elif "4090" in device_name:
return 98 * 1e12 * world_size return 98 * 1e12 * world_size
else: else:
raise NotImplementedError("Device not supported: {}.".format(device_name)) raise NotImplementedError(f"Device not supported: {device_name}.")
def calculate_mfu( def calculate_mfu(
@ -140,10 +139,10 @@ def calculate_mfu(
"bf16": True, "bf16": True,
} }
if deepspeed_stage in [2, 3]: if deepspeed_stage in [2, 3]:
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage) args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args) run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f: with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f) result = json.load(f)
if dist.is_initialized(): if dist.is_initialized():
@ -157,7 +156,7 @@ def calculate_mfu(
* compute_model_flops(model_name_or_path, total_batch_size, seq_length) * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size) / compute_device_flops(world_size)
) )
print("MFU: {:.2f}%".format(mfu_value * 100)) print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__": if __name__ == "__main__":
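For orientation, the printed value is model FLOPs utilization. A hedged sketch of the ratio (the script's exact throughput measurement differs):

```python
# Hypothetical helper: achieved model FLOPs per second divided by aggregate peak FLOPs.
def mfu(model_flops_per_step: float, steps_per_second: float, peak_flops_per_device: float, world_size: int) -> float:
    return model_flops_per_step * steps_per_second / (peak_flops_per_device * world_size)

# e.g. two RTX 4090s at 98 TFLOPs each (per the lookup above); other numbers are illustrative.
print(f"MFU: {mfu(1.0e15, 0.10, 98e12, 2) * 100:.2f}%")  # -> MFU: 51.02%
```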


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -100,7 +99,7 @@ def calculate_ppl(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
) )
else: else:
raise NotImplementedError("Stage does not supported: {}.".format(stage)) raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none") criterion = torch.nn.CrossEntropyLoss(reduction="none")
@ -125,8 +124,8 @@ def calculate_ppl(
with open(save_name, "w", encoding="utf-8") as f: with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2) json.dump(perplexities, f, indent=2)
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities))) print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
print("Perplexities have been saved at {}.".format(save_name)) print(f"Perplexities have been saved at {save_name}.")
if __name__ == "__main__": if __name__ == "__main__":


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -61,7 +60,7 @@ def length_cdf(
for length, count in length_tuples: for length, count in length_tuples:
count_accu += count count_accu += count
prob_accu += count / total_num * 100 prob_accu += count / total_num * 100
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval)) print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
if __name__ == "__main__": if __name__ == "__main__":


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Tencent Inc. and the LlamaFactory team. # Copyright 2024 Tencent Inc. and the LlamaFactory team.
# #
# This code is inspired by the Tencent's LLaMA-Pro library. # This code is inspired by the Tencent's LLaMA-Pro library.
@ -40,7 +39,7 @@ if TYPE_CHECKING:
def change_name(name: str, old_index: int, new_index: int) -> str: def change_name(name: str, old_index: int, new_index: int) -> str:
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index)) return name.replace(f".{old_index:d}.", f".{new_index:d}.")
def block_expansion( def block_expansion(
@ -76,27 +75,27 @@ def block_expansion(
state_dict = model.state_dict() state_dict = model.state_dict()
if num_layers % num_expand != 0: if num_layers % num_expand != 0:
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand)) raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
split = num_layers // num_expand split = num_layers // num_expand
layer_cnt = 0 layer_cnt = 0
output_state_dict = OrderedDict() output_state_dict = OrderedDict()
for i in range(num_layers): for i in range(num_layers):
for key, value in state_dict.items(): for key, value in state_dict.items():
if ".{:d}.".format(i) in key: if f".{i:d}." in key:
output_state_dict[change_name(key, i, layer_cnt)] = value output_state_dict[change_name(key, i, layer_cnt)] = value
print("Add layer {} copied from layer {}".format(layer_cnt, i)) print(f"Add layer {layer_cnt} copied from layer {i}")
layer_cnt += 1 layer_cnt += 1
if (i + 1) % split == 0: if (i + 1) % split == 0:
for key, value in state_dict.items(): for key, value in state_dict.items():
if ".{:d}.".format(i) in key: if f".{i:d}." in key:
if "down_proj" in key or "o_proj" in key: if "down_proj" in key or "o_proj" in key:
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value) output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
else: else:
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value) output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
print("Add layer {} expanded from layer {}".format(layer_cnt, i)) print(f"Add layer {layer_cnt} expanded from layer {i}")
layer_cnt += 1 layer_cnt += 1
for key, value in state_dict.items(): for key, value in state_dict.items():
@ -113,17 +112,17 @@ def block_expansion(
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name))) print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("finetuning_type: freeze") print("finetuning_type: freeze")
print("freeze_trainable_layers: {}".format(num_expand)) print(f"freeze_trainable_layers: {num_expand}")
print("use_llama_pro: true") print("use_llama_pro: true")


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME))) print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
def save_config(input_dir: str, output_dir: str): def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f) llama2_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict["architectures"] = ["LlamaForCausalLM"] llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2) json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_baichuan2( def llamafy_baichuan2(


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
elif "lm_head" in key: elif "lm_head" in key:
llama2_state_dict[key] = value llama2_state_dict[key] = value
else: else:
raise KeyError("Unable to process key {}".format(key)) raise KeyError(f"Unable to process key {key}")
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name) shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name))) print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
return str(torch_dtype).replace("torch.", "") return str(torch_dtype).replace("torch.", "")
def save_config(input_dir: str, output_dir: str, torch_dtype: str): def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f) qwen_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict: Dict[str, Any] = OrderedDict() llama2_config_dict: Dict[str, Any] = OrderedDict()
@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2) json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_qwen( def llamafy_qwen(


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is based on the HuggingFace's PEFT library. # This code is based on the HuggingFace's PEFT library.
@ -70,19 +69,19 @@ def quantize_loftq(
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir)) setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors) peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir)) print(f"Adapter weights saved in {loftq_dir}")
# Save base model # Save base model
base_model: "PreTrainedModel" = peft_model.unload() base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("adapter_name_or_path: {}".format(loftq_dir)) print(f"adapter_name_or_path: {loftq_dir}")
print("finetuning_type: lora") print("finetuning_type: lora")
print("quantization_bit: {}".format(loftq_bits)) print(f"quantization_bit: {loftq_bits}")
if __name__ == "__main__": if __name__ == "__main__":


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is based on the HuggingFace's PEFT library. # This code is based on the HuggingFace's PEFT library.
@ -54,7 +53,7 @@ def quantize_pissa(
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout, lora_dropout=lora_dropout,
target_modules=lora_target, target_modules=lora_target,
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter), init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
) )
# Init PiSSA model # Init PiSSA model
@ -65,17 +64,17 @@ def quantize_pissa(
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir)) setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors) peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir)) print(f"Adapter weights saved in {pissa_dir}")
# Save base model # Save base model
base_model: "PreTrainedModel" = peft_model.unload() base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("adapter_name_or_path: {}".format(pissa_dir)) print(f"adapter_name_or_path: {pissa_dir}")
print("finetuning_type: lora") print("finetuning_type: lora")
print("pissa_init: false") print("pissa_init: false")
print("pissa_convert: true") print("pissa_convert: true")

scripts/test_image.py (new file, 65 lines)

@ -0,0 +1,65 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from openai import OpenAI
from transformers.utils.versions import require_version
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def main():
client = OpenAI(
api_key="{}".format(os.environ.get("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
messages = []
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": "Output the color and number of each box."},
{
"type": "image_url",
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
},
],
}
)
result = client.chat.completions.create(messages=messages, model="test")
messages.append(result.choices[0].message)
print("Round 1:", result.choices[0].message.content)
# The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": "What kind of flower is this?"},
{
"type": "image_url",
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
},
],
}
)
result = client.chat.completions.create(messages=messages, model="test")
messages.append(result.choices[0].message)
print("Round 2:", result.choices[0].message.content)
# The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
if __name__ == "__main__":
main()
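One possible way to drive this script, assuming an OpenAI-compatible LLaMA Factory API server for a vision-language model is already listening on the port below:

```python
import os
import subprocess

# Hypothetical driver: set the environment the script reads, then run it.
os.environ.setdefault("API_PORT", "8000")  # must match the running API server
os.environ.setdefault("API_KEY", "0")      # the script falls back to "0" when unset
subprocess.run(["python", "scripts/test_image.py"], check=True)
```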


@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");


@ -20,7 +20,7 @@ from setuptools import find_packages, setup
def get_version() -> str: def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f: with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content) (version,) = re.findall(pattern, file_content)
@ -28,7 +28,7 @@ def get_version() -> str:
def get_requires() -> List[str]: def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f: with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines return lines
@ -54,13 +54,14 @@ extra_require = {
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"], "awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"], "aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.6.0"], "vllm": ["vllm>=0.4.3,<=0.6.3"],
"galore": ["galore-torch"], "galore": ["galore-torch"],
"badam": ["badam>=1.2.1"], "badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"], "adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"], "qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"], "modelscope": ["modelscope"],
"dev": ["ruff", "pytest"], "openmind": ["openmind"],
"dev": ["pre-commit", "ruff", "pytest"],
} }
@ -71,7 +72,7 @@ def main():
author="hiyouga", author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn", author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework", description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"], keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License", license="Apache 2.0 License",


@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
def main(): def main():
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0") api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000")) api_port = int(os.getenv("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port)) print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port) uvicorn.run(app, host=api_host, port=api_port)


@ -20,17 +20,17 @@ Level:
Dependency graph: Dependency graph:
main: main:
transformers>=4.41.2,<=4.45.0 transformers>=4.41.2,<=4.46.1
datasets>=2.16.0,<=2.21.0 datasets>=2.16.0,<=3.0.2
accelerate>=0.30.1,<=0.34.2 accelerate>=0.34.0,<=1.0.1
peft>=0.11.1,<=0.12.0 peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6 trl>=0.8.6,<=0.9.6
attention: attention:
transformers>=4.42.4 (gemma+fa2) transformers>=4.42.4 (gemma+fa2)
longlora: longlora:
transformers>=4.41.2,<=4.45.0 transformers>=4.41.2,<=4.46.1
packing: packing:
transformers>=4.41.2,<=4.45.0 transformers>=4.41.2,<=4.46.1
Disable version checking: DISABLE_VERSION_CHECK=1 Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1 Enable VRAM recording: RECORD_VRAM=1
@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1 Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1 Use modelscope: USE_MODELSCOPE_HUB=1
Use openmind: USE_OPENMIND_HUB=1
""" """
from .extras.env import VERSION from .extras.env import VERSION


@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem
def create_app(chat_model: "ChatModel") -> "FastAPI": def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.environ.get("FASTAPI_ROOT_PATH", "") root_path = os.getenv("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path) app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
api_key = os.environ.get("API_KEY", None) api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)], dependencies=[Depends(verify_api_key)],
) )
async def list_models(): async def list_models():
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo")) model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card]) return ModelList(data=[model_card])
@app.post( @app.post(
@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
def run_api() -> None: def run_api() -> None:
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0") api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000")) api_port = int(os.getenv("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port)) print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port) uvicorn.run(app, host=api_host, port=api_port)


@ -21,7 +21,7 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole from ..data import Role as DataRole
from ..extras.logging import get_logger from ..extras import logging
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify from .common import dictify, jsonify
from .protocol import ( from .protocol import (
@ -57,7 +57,7 @@ if TYPE_CHECKING:
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__) logger = logging.get_logger(__name__)
ROLE_MAPPING = { ROLE_MAPPING = {
Role.USER: DataRole.USER.value, Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value, Role.ASSISTANT: DataRole.ASSISTANT.value,
@ -69,8 +69,8 @@ ROLE_MAPPING = {
def _process_request( def _process_request(
request: "ChatCompletionRequest", request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]: ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False))) logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@ -84,7 +84,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = [] input_messages = []
image = None images = []
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@ -111,7 +111,7 @@ def _process_request(
else: # web uri else: # web uri
image_stream = requests.get(image_url, stream=True).raw image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_stream).convert("RGB") images.append(Image.open(image_stream).convert("RGB"))
else: else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@ -124,7 +124,7 @@ def _process_request(
else: else:
tools = None tools = None
return input_messages, system, tools, image return input_messages, system, tools, images or None
def _create_stream_chat_completion_chunk( def _create_stream_chat_completion_chunk(
@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
async def create_chat_completion_response( async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse": ) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request) input_messages, system, tools, images = _process_request(request)
responses = await chat_model.achat( responses = await chat_model.achat(
input_messages, input_messages,
system, system,
tools, tools,
image, images,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@ -169,7 +169,7 @@ async def create_chat_completion_response(
tool_calls = [] tool_calls = []
for tool in result: for tool in result:
function = Function(name=tool[0], arguments=tool[1]) function = Function(name=tool[0], arguments=tool[1])
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)) tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL finish_reason = Finish.TOOL
@ -193,8 +193,8 @@ async def create_chat_completion_response(
async def create_stream_chat_completion_response( async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request) input_messages, system, tools, images = _process_request(request)
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
input_messages, input_messages,
system, system,
tools, tools,
image, images,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
async def create_score_evaluation_response( async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel" request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse": ) -> "ScoreEvaluationResponse":
score_id = "scoreval-{}".format(uuid.uuid4().hex) score_id = f"scoreval-{uuid.uuid4().hex}"
if len(request.messages) == 0: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")


@ -66,8 +66,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
@ -81,8 +81,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r""" r"""


@ -53,7 +53,7 @@ class ChatModel:
elif model_args.infer_backend == "vllm": elif model_args.infer_backend == "vllm":
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args) self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else: else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend)) raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
self._loop = asyncio.new_event_loop() self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
@ -64,15 +64,15 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
Gets a list of responses of the chat model. Gets a list of responses of the chat model.
""" """
task = asyncio.run_coroutine_threadsafe( task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
) )
return task.result() return task.result()
@ -81,28 +81,28 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
Asynchronously gets a list of responses of the chat model. Asynchronously gets a list of responses of the chat model.
""" """
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs) return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
def stream_chat( def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
r""" r"""
Gets the response token-by-token of the chat model. Gets the response token-by-token of the chat model.
""" """
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs) generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
while True: while True:
try: try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@ -115,14 +115,14 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r""" r"""
Asynchronously gets the response token-by-token of the chat model. Asynchronously gets the response token-by-token of the chat model.
""" """
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs): async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
yield new_token yield new_token
def get_scores( def get_scores(


@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override from typing_extensions import override
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class HuggingfaceEngine(BaseEngine): class HuggingfaceEngine(BaseEngine):
@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
try: try:
asyncio.get_event_loop() asyncio.get_event_loop()
except RuntimeError: except RuntimeError:
logger.warning("There is no current event loop, creating a new one.") logger.warning_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1"))) self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@staticmethod @staticmethod
def _process_args( def _process_args(
@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]} mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None: if images is not None:
mm_input_dict.update({"images": [image], "imglens": [1]}) mm_input_dict.update({"images": images, "imglens": [len(images)]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]: if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if video is not None: if videos is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]}) mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]: if not any(VIDEO_PLACEHOLDER not in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
messages = template.mm_plugin.process_messages( messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor messages, mm_input_dict["images"], mm_input_dict["videos"], processor
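In effect, when images arrive without explicit placeholders, one placeholder per image is prepended to the first message. A sketch of the intended behaviour, with the placeholder string assumed:

```python
IMAGE_PLACEHOLDER = "<image>"  # assumed value of the constant from extras.constants
messages = [{"role": "user", "content": "Compare these two pictures."}]
images = ["a.png", "b.png"]

# Prepend one placeholder per image when no message already contains one.
if not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
    messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

print(messages[0]["content"])  # <image><image>Compare these two pictures.
```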
@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None: if stop is not None:
logger.warning("Stop parameter is not supported by the huggingface engine yet.") logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
generating_args = generating_args.copy() generating_args = generating_args.copy()
generating_args.update( generating_args.update(
@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine):
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor) mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
for key, value in mm_inputs.items(): for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value) if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
gen_kwargs[key] = value.to(model.device) gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length return gen_kwargs, prompt_length
@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]: ) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args( gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
) )
generate_output = model.generate(**gen_kwargs) generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]: ) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args( gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
) )
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
if not self.can_generate: if not self.can_generate:
@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
messages, messages,
system, system,
tools, tools,
image, images,
video, videos,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
if not self.can_generate: if not self.can_generate:
@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
messages, messages,
system, system,
tools, tools,
image, images,
video, videos,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:


@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
from typing_extensions import override from typing_extensions import override
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer from ..model import load_config, load_tokenizer
@ -43,7 +43,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class VllmEngine(BaseEngine): class VllmEngine(BaseEngine):
@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.") logger.info_rank0("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncIterator["RequestOutput"]: ) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex) request_id = f"chatcmpl-{uuid.uuid4().hex}"
if image is not None: if images is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]: if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"] system = system or self.generating_args["default_system"]
@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
skip_special_tokens=True, skip_special_tokens=True,
) )
if image is not None: # add image features if images is not None: # add image features
if not isinstance(image, (str, ImageObject)): image_data = []
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image))) for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str): if isinstance(image, str):
image = Image.open(image).convert("RGB") image = Image.open(image).convert("RGB")
multi_modal_data = {"image": image} image_data.append(image)
multi_modal_data = {"image": image_data}
else: else:
multi_modal_data = None multi_modal_data = None
@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
final_output = None final_output = None
generator = await self._generate(messages, system, tools, image, video, **input_kwargs) generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for request_output in generator: async for request_output in generator:
final_output = request_output final_output = request_output
@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
generated_text = "" generated_text = ""
generator = await self._generate(messages, system, tools, image, video, **input_kwargs) generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for result in generator: async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :] delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text generated_text = result.outputs[0].text


@ -22,8 +22,8 @@ from . import launcher
from .api.app import run_api from .api.app import run_api
from .chat.chat_model import run_chat from .chat.chat_model import run_chat
from .eval.evaluator import run_eval from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui from .webui.interface import run_web_demo, run_web_ui
@ -47,7 +47,7 @@ USAGE = (
WELCOME = ( WELCOME = (
"-" * 58 "-" * 58
+ "\n" + "\n"
+ "| Welcome to LLaMA Factory, version {}".format(VERSION) + f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION)) + " " * (21 - len(VERSION))
+ "|\n|" + "|\n|"
+ " " * 56 + " " * 56
@ -56,7 +56,7 @@ WELCOME = (
+ "-" * 58 + "-" * 58
) )
logger = get_logger(__name__) logger = logging.get_logger(__name__)
@unique @unique
@ -86,19 +86,19 @@ def main():
elif command == Command.EXPORT: elif command == Command.EXPORT:
export_model() export_model()
elif command == Command.TRAIN: elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"] force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1: if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
process = subprocess.run( process = subprocess.run(
( (
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}" "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format( ).format(
nnodes=os.environ.get("NNODES", "1"), nnodes=os.getenv("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"), node_rank=os.getenv("NODE_RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())), nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
file_name=launcher.__file__, file_name=launcher.__file__,
@ -118,4 +118,4 @@ def main():
elif command == Command.HELP: elif command == Command.HELP:
print(USAGE) print(USAGE)
else: else:
raise NotImplementedError("Unknown command: {}.".format(command)) raise NotImplementedError(f"Unknown command: {command}.")


@ -16,7 +16,7 @@ import os
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras.logging import get_logger from ..extras import logging
from .data_utils import Role from .data_utils import Role
@ -29,45 +29,51 @@ if TYPE_CHECKING:
from .parser import DatasetAttr from .parser import DatasetAttr
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _convert_images( def _convert_images(
images: Sequence["ImageInput"], images: Union["ImageInput", Sequence["ImageInput"]],
dataset_attr: "DatasetAttr", dataset_attr: "DatasetAttr",
data_args: "DataArguments", data_args: "DataArguments",
) -> Optional[List["ImageInput"]]: ) -> Optional[List["ImageInput"]]:
r""" r"""
Optionally concatenates image path to dataset dir when loading from local disk. Optionally concatenates image path to dataset dir when loading from local disk.
""" """
if len(images) == 0: if not isinstance(images, list):
images = [images]
elif len(images) == 0:
return None return None
else:
images = images[:]
images = images[:]
if dataset_attr.load_from in ["script", "file"]: if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)): for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])): if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
images[i] = os.path.join(data_args.dataset_dir, images[i]) images[i] = os.path.join(data_args.image_dir, images[i])
return images return images
def _convert_videos( def _convert_videos(
videos: Sequence["VideoInput"], videos: Union["VideoInput", Sequence["VideoInput"]],
dataset_attr: "DatasetAttr", dataset_attr: "DatasetAttr",
data_args: "DataArguments", data_args: "DataArguments",
) -> Optional[List["VideoInput"]]: ) -> Optional[List["VideoInput"]]:
r""" r"""
Optionally concatenates video path to dataset dir when loading from local disk. Optionally concatenates video path to dataset dir when loading from local disk.
""" """
if len(videos) == 0: if not isinstance(videos, list):
videos = [videos]
elif len(videos) == 0:
return None return None
else:
videos = videos[:]
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]: if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)): for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])): if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
videos[i] = os.path.join(data_args.dataset_dir, videos[i]) videos[i] = os.path.join(data_args.image_dir, videos[i])
return videos return videos
@ -161,7 +167,7 @@ def convert_sharegpt(
broken_data = False broken_data = False
for turn_idx, message in enumerate(messages): for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages)) logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True broken_data = True
aligned_messages.append( aligned_messages.append(
@ -171,7 +177,7 @@ def convert_sharegpt(
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0 dataset_attr.ranking and len(aligned_messages) % 2 == 0
): ):
logger.warning("Invalid message count in {}.".format(messages)) logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
@ -192,7 +198,7 @@ def convert_sharegpt(
chosen[dataset_attr.role_tag] not in accept_tags[-1] chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1] or rejected[dataset_attr.role_tag] not in accept_tags[-1]
): ):
logger.warning("Invalid role tag in {}.".format([chosen, rejected])) logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True broken_data = True
prompt = aligned_messages prompt = aligned_messages
@ -205,7 +211,7 @@ def convert_sharegpt(
response = aligned_messages[-1:] response = aligned_messages[-1:]
if broken_data: if broken_data:
logger.warning("Skipping this abnormal example.") logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], [] prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
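For reference, the normalization added above — wrapping a bare image into a list and resolving relative paths against `image_dir` when loading from a script or local file — behaves roughly like this standalone sketch (argument names simplified; not the actual function signature):

```python
import os
from typing import List, Optional, Sequence, Union


def convert_images_sketch(
    images: Union[str, Sequence[str]],
    image_dir: str,
    load_from: str,
) -> Optional[List[str]]:
    """Simplified stand-in for _convert_images: normalize to a list and resolve local paths."""
    if not isinstance(images, list):
        images = [images]          # a bare image becomes a one-element list
    elif len(images) == 0:
        return None                # an empty list means "no images"
    else:
        images = images[:]         # copy so the original example is untouched

    if load_from in ["script", "file"]:
        for i in range(len(images)):
            candidate = os.path.join(image_dir, images[i])
            if isinstance(images[i], str) and os.path.isfile(candidate):
                images[i] = candidate

    return images


# convert_images_sketch("cat.jpg", image_dir="data/images", load_from="file")
# -> ["data/images/cat.jpg"] if that file exists, otherwise ["cat.jpg"]
```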
View File
@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: Dict[str, "torch.Tensor"] = super().__call__(features) features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update(mm_inputs) features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
return features return features
@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
for key in ("chosen", "rejected"): for key in ("chosen", "rejected"):
for feature in features: for feature in features:
target_feature = { target_feature = {
"input_ids": feature["{}_input_ids".format(key)], "input_ids": feature[f"{key}_input_ids"],
"attention_mask": feature["{}_attention_mask".format(key)], "attention_mask": feature[f"{key}_attention_mask"],
"labels": feature["{}_labels".format(key)], "labels": feature[f"{key}_labels"],
"images": feature["images"], "images": feature["images"],
"videos": feature["videos"], "videos": feature["videos"],
} }
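A minimal sketch of what the pairwise collator above produces: every example is expanded into one "chosen" and one "rejected" feature dict (2 * batch_size features in total) before ordinary seq2seq padding. Only the key names come from the diff; the helper itself is illustrative.

```python
from typing import Any, Dict, List


def expand_pairwise_features(features: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten chosen/rejected pairs into 2 * batch_size ordinary seq2seq features."""
    concatenated = []
    for key in ("chosen", "rejected"):
        for feature in features:
            concatenated.append(
                {
                    "input_ids": feature[f"{key}_input_ids"],
                    "attention_mask": feature[f"{key}_attention_mask"],
                    "labels": feature[f"{key}_labels"],
                    "images": feature["images"],
                    "videos": feature["videos"],
                }
            )
    return concatenated
```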
View File
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
from datasets import DatasetDict, concatenate_datasets, interleave_datasets from datasets import DatasetDict, concatenate_datasets, interleave_datasets
from ..extras.logging import get_logger from ..extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments from ..hparams import DataArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0] return all_datasets[0]
elif data_args.mix_strategy == "concat": elif data_args.mix_strategy == "concat":
if data_args.streaming: if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.") logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets) return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"): elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming: if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets( return interleave_datasets(
datasets=all_datasets, datasets=all_datasets,
@ -70,7 +70,7 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
) )
else: else:
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy)) raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset( def split_dataset(
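For context, the two mixing strategies referenced above map directly onto the standard `datasets` helpers; a minimal, self-contained sketch (dataset contents and probabilities are invented):

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2"]})

# mix_strategy=concat: simple concatenation, order preserved
merged = concatenate_datasets([ds_a, ds_b])

# mix_strategy=interleave_under: sample according to probabilities and
# stop once the first dataset runs out ("all_exhausted" for *_over)
interleaved = interleave_datasets(
    datasets=[ds_a, ds_b],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy="first_exhausted",
)
```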
View File
@ -83,14 +83,14 @@ class StringFormatter(Formatter):
if isinstance(slot, str): if isinstance(slot, str):
for name, value in kwargs.items(): for name, value in kwargs.items():
if not isinstance(value, str): if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value)) raise RuntimeError(f"Expected a string, got {value}")
slot = slot.replace("{{" + name + "}}", value, 1) slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot) elements.append(slot)
elif isinstance(slot, (dict, set)): elif isinstance(slot, (dict, set)):
elements.append(slot) elements.append(slot)
else: else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements return elements
@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError: except json.JSONDecodeError:
functions = [] raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = [] elements = []
for name, arguments in functions: for name, arguments in functions:
@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
elif isinstance(slot, (dict, set)): elif isinstance(slot, (dict, set)):
elements.append(slot) elements.append(slot)
else: else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements return elements
@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content) tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""] return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError: except json.JSONDecodeError:
return [""] raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
@override @override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
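To make the stricter error handling above concrete, a small sketch of parsing a tool description the way the formatter now does, raising instead of silently returning an empty result; `render_tools` is a hypothetical stand-in, not the project's formatter class.

```python
import json
from typing import Any, Dict, List


def render_tools(content: str) -> str:
    """Parse a JSON tool description and fail loudly on malformed input."""
    try:
        tools: List[Dict[str, Any]] = json.loads(content)
    except json.JSONDecodeError:
        # Mirrors the new behaviour: surface the bad payload instead of returning "".
        raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}")

    if len(tools) == 0:
        return ""

    return "\n".join(tool.get("name", "") for tool in tools)


# render_tools('[{"name": "get_weather"}]') -> "get_weather"
# render_tools("not json")                  -> RuntimeError
```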
View File
@ -20,8 +20,8 @@ import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data from ..extras.misc import has_tokenized_data
from .aligner import align_dataset from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset from .data_utils import merge_dataset, split_dataset
@ -39,7 +39,7 @@ if TYPE_CHECKING:
from .template import Template from .template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _load_single_dataset( def _load_single_dataset(
@ -51,9 +51,9 @@ def _load_single_dataset(
r""" r"""
Loads a single dataset and aligns it to the standard format. Loads a single dataset and aligns it to the standard format.
""" """
logger.info("Loading dataset {}...".format(dataset_attr)) logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset data_name = dataset_attr.subset
data_dir = dataset_attr.folder data_dir = dataset_attr.folder
@ -69,25 +69,24 @@ def _load_single_dataset(
if os.path.isdir(local_path): # is directory if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path): for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name)) data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file elif os.path.isfile(local_path): # is file
data_files.append(local_path) data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else: else:
raise ValueError("File {} not found.".format(local_path)) raise ValueError(f"File {local_path} not found.")
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
if data_path is None: if data_path is None:
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys()))) raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
raise ValueError("File types should be identical.")
else: else:
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from)) raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub": if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import MsDataset from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load( dataset = MsDataset.load(
@ -98,10 +97,27 @@ def _load_single_dataset(
split=dataset_attr.split, split=dataset_attr.split,
cache_dir=cache_dir, cache_dir=cache_dir,
token=model_args.ms_hub_token, token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), use_streaming=data_args.streaming,
) )
if isinstance(dataset, MsDataset): if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset() dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=data_args.streaming,
)
else: else:
dataset = load_dataset( dataset = load_dataset(
path=data_path, path=data_path,
@ -111,13 +127,10 @@ def _load_single_dataset(
split=dataset_attr.split, split=dataset_attr.split,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")), streaming=data_args.streaming,
trust_remote_code=True, trust_remote_code=True,
) )
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
@ -128,7 +141,7 @@ def _load_single_dataset(
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes) dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset)) max_samples = min(data_args.max_samples, len(dataset))
@ -224,9 +237,9 @@ def get_dataset(
# Load tokenized dataset # Load tokenized dataset
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.") logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {} dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict: if "train" in dataset_dict:
@ -277,8 +290,8 @@ def get_dataset(
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if training_args.should_save: if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path) dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0) sys.exit(0)
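A minimal sketch of the extension check introduced earlier in this file: the loader type is inferred from the first collected file, and every other file must map to the same type. `FILEEXT2TYPE` is approximated here with a small literal dict; the real mapping lives in `extras.constants`.

```python
import os
from typing import List, Optional

FILEEXT2TYPE = {"json": "json", "jsonl": "json", "csv": "csv", "parquet": "parquet", "txt": "text"}


def infer_data_path(data_files: List[str]) -> str:
    """Infer the datasets loader type from file extensions, requiring them to be identical."""
    data_path: Optional[str] = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
    if data_path is None:
        raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

    if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
        raise ValueError("File types should be identical.")

    return data_path


# infer_data_path(["train.json", "eval.json"]) -> "json"
# infer_data_path(["train.json", "eval.csv"])  -> ValueError
```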
View File
@ -4,6 +4,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np import numpy as np
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@ -110,7 +111,7 @@ class BasePlugin:
image = Image.open(image["path"]) image = Image.open(image["path"])
if not isinstance(image, ImageObject): if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image))) raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
results.append(self._preprocess_image(image, **kwargs)) results.append(self._preprocess_image(image, **kwargs))
@ -157,6 +158,7 @@ class BasePlugin:
It holds num_patches == torch.prod(image_grid_thw) It holds num_patches == torch.prod(image_grid_thw)
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
input_dict = {"images": None} # default key input_dict = {"images": None} # default key
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
@ -174,10 +176,16 @@ class BasePlugin:
) )
input_dict["videos"] = videos input_dict["videos"] = videos
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None: mm_inputs = {}
return image_processor(**input_dict, return_tensors="pt") if image_processor != video_processor:
else: if input_dict.get("images") is not None:
return {} mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
if input_dict.get("videos") is not None:
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
return mm_inputs
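A hedged sketch of the dispatch added above: models that ship a dedicated video processor get images and videos preprocessed separately, while models whose image processor also handles videos (e.g. Qwen2-VL) receive both in one call. The processor objects below are untyped placeholders, not real `transformers` classes.

```python
from typing import Any, Dict, Optional


def collect_mm_inputs(
    image_processor: Any,
    video_processor: Any,
    images: Optional[list],
    videos: Optional[list],
) -> Dict[str, Any]:
    """Route images/videos to the right processor, merging their tensor dicts."""
    mm_inputs: Dict[str, Any] = {}
    if image_processor != video_processor:  # separate processors: call each one on its own inputs
        if images is not None:
            mm_inputs.update(image_processor(images, return_tensors="pt"))
        if videos is not None:
            mm_inputs.update(video_processor(videos, return_tensors="pt"))
    elif images is not None or videos is not None:  # same processor handles both modalities
        mm_inputs.update(image_processor(images=images, videos=videos, return_tensors="pt"))
    return mm_inputs
```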
def process_messages( def process_messages(
self, self,
@ -218,6 +226,14 @@ class BasePlugin:
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r""" r"""
Builds batched multimodal inputs for VLMs. Builds batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
seqlens: number of tokens in each sample, shape (batch_size,)
processor: a processor for pre-processing images and videos
""" """
self._validate_input(images, videos) self._validate_input(images, videos)
return {} return {}
@ -245,7 +261,123 @@ class LlavaPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen) message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "image_sizes" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while self.image_token in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
res = self._get_mm_inputs(images, videos, processor)
return res
class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while self.image_token in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
for message in messages:
content = message["content"]
while self.video_token in content:
num_video_tokens += 1
content = content.replace(self.video_token, "{{video}}", 1)
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
return messages return messages
@ -284,7 +416,7 @@ class PaliGemmaPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", "") message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages return messages
@ -324,6 +456,68 @@ class PaliGemmaPlugin(BasePlugin):
return mm_inputs return mm_inputs
class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = mm_inputs.get("image_sizes", None)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
mm_inputs.pop("image_sizes", None)
return mm_inputs
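To illustrate the replacement loop in `PixtralPlugin.process_messages`, a standalone sketch of expanding one image of a given size into a row-wise grid of image tokens terminated by break/end markers. The token strings and the use of `patch_size` follow the diff; the function itself and its default patch size are only illustrative.

```python
def pixtral_image_tokens(height: int, width: int, patch_size: int = 16) -> str:
    """Expand one image placeholder into a row-wise grid of image tokens."""
    image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size
    rows = [[image_token] * num_width_tokens + [image_break_token] for _ in range(num_height_tokens)]
    replace_tokens = [token for row in rows for token in row]  # flatten row-major
    replace_tokens[-1] = image_end_token                       # the final break becomes the end marker
    return "".join(replace_tokens)


# pixtral_image_tokens(32, 48) -> "[IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]"
```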
class Qwen2vlPlugin(BasePlugin): class Qwen2vlPlugin(BasePlugin):
@override @override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@ -369,7 +563,7 @@ class Qwen2vlPlugin(BasePlugin):
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw): if num_image_tokens >= len(image_grid_thw):
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER)) raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace( content = content.replace(
IMAGE_PLACEHOLDER, IMAGE_PLACEHOLDER,
@ -382,7 +576,7 @@ class Qwen2vlPlugin(BasePlugin):
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw): if num_video_tokens >= len(video_grid_thw):
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER)) raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
content = content.replace( content = content.replace(
VIDEO_PLACEHOLDER, VIDEO_PLACEHOLDER,
@ -396,10 +590,73 @@ class Qwen2vlPlugin(BasePlugin):
message["content"] = content message["content"] = content
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
if len(videos) != num_video_tokens: if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER)) raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_frames = 0
exist_images = "pixel_values_images" in mm_inputs
exist_videos = "pixel_values_videos" in mm_inputs
if exist_videos or exist_images:
if exist_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if exist_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
for message in messages:
content = message["content"]
while self.image_token in content:
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}", 1)
while self.video_token in content:
num_video_tokens += 1
content = content.replace(self.video_token, "{{video}}", 1)
content = content.replace("{{image}}", self.image_token * image_seqlen)
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
return messages return messages
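As a quick numeric illustration of the sequence-length arithmetic used above for Video-LLaVA: each frame contributes (height // patch_size) * (width // patch_size) + 1 tokens, a video contributes that amount per frame, and one token is dropped for single images under the "default" feature-selection strategy. The concrete numbers below are only an example.

```python
from typing import Tuple


def video_llava_seqlens(
    height: int, width: int, patch_size: int, num_frames: int, strategy: str = "default"
) -> Tuple[int, int]:
    """Return (image_seqlen, video_seqlen) following the logic in VideoLlavaPlugin."""
    image_seqlen = (height // patch_size) * (width // patch_size) + 1
    video_seqlen = image_seqlen * num_frames
    if strategy == "default":
        image_seqlen -= 1  # applies to single images only, not to video frames
    return image_seqlen, video_seqlen


# e.g. 224x224 input, 14-pixel patches, 8 video frames:
# video_llava_seqlens(224, 224, 14, 8) -> (256, 2056)
```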
@ -420,8 +677,12 @@ class Qwen2vlPlugin(BasePlugin):
PLUGINS = { PLUGINS = {
"base": BasePlugin, "base": BasePlugin,
"llava": LlavaPlugin, "llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"paligemma": PaliGemmaPlugin, "paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin, "qwen2_vl": Qwen2vlPlugin,
"video_llava": VideoLlavaPlugin,
} }
@ -432,6 +693,6 @@ def get_mm_plugin(
) -> "BasePlugin": ) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None) plugin_class = PLUGINS.get(name, None)
if plugin_class is None: if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name)) raise ValueError(f"Multimodal plugin `{name}` not found.")
return plugin_class(image_token, video_token) return plugin_class(image_token, video_token)
View File
@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
from transformers.utils import cached_file from transformers.utils import cached_file
from ..extras.constants import DATA_CONFIG from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope from ..extras.misc import use_modelscope, use_openmind
@dataclass @dataclass
@ -30,7 +30,7 @@ class DatasetAttr:
""" """
# basic configs # basic configs
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca" formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False ranking: bool = False
@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
config_path = os.path.join(dataset_dir, DATA_CONFIG) config_path = os.path.join(dataset_dir, DATA_CONFIG)
try: try:
with open(config_path, "r") as f: with open(config_path) as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except Exception as err: except Exception as err:
if len(dataset_names) != 0: if len(dataset_names) != 0:
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
dataset_info = None dataset_info = None
dataset_list: List["DatasetAttr"] = [] dataset_list: List["DatasetAttr"] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE if dataset_info is None: # dataset_dir is ONLINE
load_from = "ms_hub" if use_modelscope() else "hf_hub" if use_modelscope():
load_from = "ms_hub"
elif use_openmind():
load_from = "om_hub"
else:
load_from = "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name) dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr) dataset_list.append(dataset_attr)
continue continue
if name not in dataset_info: if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
has_hf_url = "hf_hub_url" in dataset_info[name] has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name] has_ms_url = "ms_hub_url" in dataset_info[name]
has_om_url = "om_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url: if has_hf_url or has_ms_url or has_om_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url): if has_ms_url and (use_modelscope() or not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
elif has_om_url and (use_openmind() or not has_hf_url):
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
else: else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
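A small sketch of the fallback order introduced above for the online case (no usable dataset_info.json): ModelScope first if `USE_MODELSCOPE_HUB` is enabled, then openMind if `USE_OPENMIND_HUB` is enabled, otherwise the Hugging Face hub. The env-var parsing below mimics `use_modelscope()` / `use_openmind()` but is reimplemented here as an assumption for the sketch.

```python
import os


def _is_env_enabled(name: str) -> bool:
    return os.getenv(name, "0").lower() in ["true", "1"]


def choose_hub() -> str:
    """Pick the dataset hub the loader should talk to when dataset_info.json is absent."""
    if _is_env_enabled("USE_MODELSCOPE_HUB"):
        return "ms_hub"
    elif _is_env_enabled("USE_OPENMIND_HUB"):
        return "om_hub"
    return "hf_hub"


# USE_OPENMIND_HUB=1 -> "om_hub"; nothing set -> "hf_hub"
```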
View File
@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_feedback_example( def _encode_feedback_example(
@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0: if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.") logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs return model_inputs
View File
@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_pairwise_example( def _encode_pairwise_example(
@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))) print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"])) print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False))) print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False))) print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"])) print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False))) print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
View File
@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen from .processor_utils import greedy_knapsack, infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_supervised_example( def _encode_supervised_example(
@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
length2indexes = defaultdict(list) length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
) )
length = len(input_ids) length = len(input_ids)
if length > data_args.cutoff_len: if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else: else:
lengths.append(length) lengths.append(length)
length2indexes[length].append(valid_num) length2indexes[length].append(valid_num)
@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"])) print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
View File
@ -15,7 +15,7 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger from ...extras import logging
from ..data_utils import Role from ..data_utils import Role
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_unsupervised_example( def _encode_unsupervised_example(
@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1: if len(examples["_prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_unsupervised_example( input_ids, labels = _encode_unsupervised_example(
View File
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from typing_extensions import override from typing_extensions import override
from ..extras.logging import get_logger from ..extras import logging
from .data_utils import Role from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin from .mm_plugin import get_mm_plugin
@ -32,7 +32,7 @@ if TYPE_CHECKING:
from .mm_plugin import BasePlugin from .mm_plugin import BasePlugin
logger = get_logger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass
@ -49,6 +49,7 @@ class Template:
stop_words: List[str] stop_words: List[str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool
mm_plugin: "BasePlugin" mm_plugin: "BasePlugin"
def encode_oneturn( def encode_oneturn(
@ -146,7 +147,7 @@ class Template:
elif "eos_token" in elem and tokenizer.eos_token_id is not None: elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id] token_ids += [tokenizer.eos_token_id]
else: else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids return token_ids
@ -214,6 +215,7 @@ def _register_template(
stop_words: Sequence[str] = [], stop_words: Sequence[str] = [],
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None: ) -> None:
r""" r"""
@ -263,6 +265,7 @@ def _register_template(
stop_words=stop_words, stop_words=stop_words,
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
mm_plugin=mm_plugin, mm_plugin=mm_plugin,
) )
@ -272,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added: if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token)) logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
else: else:
logger.info("Replace eos token: {}".format(tokenizer.eos_token)) logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0: if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.") logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str: def _jinja_escape(content: str) -> str:
@ -353,24 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
r""" r"""
Gets chat template and fixes the tokenizer. Gets chat template and fixes the tokenizer.
""" """
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
if data_args.template is None: if data_args.template is None:
template = TEMPLATES["empty"] # placeholder template = TEMPLATES["empty"] # placeholder
else: else:
template = TEMPLATES.get(data_args.template, None) template = TEMPLATES.get(data_args.template, None)
if template is None: if template is None:
raise ValueError("Template {} does not exist.".format(data_args.template)) raise ValueError(f"Template {data_args.template} does not exist.")
if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None: if data_args.tool_format is not None:
logger.info("Using tool format: {}.".format(data_args.tool_format)) logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}] eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format) template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format) template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@ -388,20 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token)) logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words: if stop_words:
num_added_tokens = tokenizer.add_special_tokens( num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
) )
logger.info("Add {} to stop words.".format(",".join(stop_words))) logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0: if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.") logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
try: if tokenizer.chat_template is None or template.replace_jinja_template:
tokenizer.chat_template = _get_jinja_template(template, tokenizer) try:
except ValueError: tokenizer.chat_template = _get_jinja_template(template, tokenizer)
logger.info("Cannot add this chat template to tokenizer.") except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
return template return template
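A hedged paraphrase of the new condition above: a Jinja chat template is generated and attached only when the tokenizer has none or when the registered template opts in via `replace_jinja_template`; otherwise the template shipped with the tokenizer wins. `build_jinja_template` below is a stand-in for the internal `_get_jinja_template` helper.

```python
def maybe_set_chat_template(tokenizer, template, build_jinja_template) -> None:
    """Attach a generated Jinja template only when the tokenizer needs (or allows) it."""
    if tokenizer.chat_template is None or template.replace_jinja_template:
        try:
            tokenizer.chat_template = build_jinja_template(template, tokenizer)
        except ValueError as err:
            # Mirrors the new logging: report why the template could not be attached.
            print(f"Cannot add this chat template to tokenizer: {err}.")
```

Templates registered below with `replace_jinja_template=False` (gemma, llama3, qwen, qwen2_vl, and several llava_next variants) therefore keep the chat template bundled with their tokenizers.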
@ -640,6 +641,14 @@ _register_template(
) )
_register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template( _register_template(
name="falcon", name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@ -664,6 +673,7 @@ _register_template(
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]), format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True, efficient_eos=True,
replace_jinja_template=False,
) )
@ -681,6 +691,14 @@ _register_template(
) )
_register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]),
efficient_eos=True,
)
_register_template( _register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@ -740,6 +758,7 @@ _register_template(
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"], stop_words=["<|eot_id|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
) )
@ -754,6 +773,107 @@ _register_template(
) )
_register_template(
name="llava_next",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_video",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template(
name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template( _register_template(
name="mistral", name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@ -831,6 +951,14 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
_register_template( _register_template(
name="qwen", name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -840,6 +968,7 @@ _register_template(
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
) )
@ -852,6 +981,7 @@ _register_template(
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
) )
@ -907,6 +1037,17 @@ _register_template(
) )
_register_template(
name="video_llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
)
_register_template( _register_template(
name="xuanyuan", name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),

View File

@ -177,6 +177,6 @@ TOOLS = {
def get_tool_utils(name: str) -> "ToolUtils": def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None) tool_utils = TOOLS.get(name, None)
if tool_utils is None: if tool_utils is None:
raise ValueError("Tool utils `{}` not found.".format(name)) raise ValueError(f"Tool utils `{name}` not found.")
return tool_utils return tool_utils

View File

@ -87,7 +87,7 @@ class Evaluator:
token=self.model_args.hf_hub_token, token=self.model_args.hf_hub_token,
) )
with open(mapping, "r", encoding="utf-8") as f: with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS} category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@ -139,7 +139,7 @@ class Evaluator:
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None: def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join( score_info = "\n".join(
[ [
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
for category_name, category_correct in category_corrects.items() for category_name, category_correct in category_corrects.items()
if len(category_correct) if len(category_correct)
] ]

View File

@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
def get_eval_template(name: str) -> "EvalTemplate": def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None) eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name) assert eval_template is not None, f"Template {name} does not exist."
return eval_template return eval_template

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Dict, Optional
@ -47,7 +48,7 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100 IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = "<image>" IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
LAYERNORM_NAMES = {"norm", "ln"} LAYERNORM_NAMES = {"norm", "ln"}
@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = "<video>" VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"
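The two placeholder constants above can now be overridden through environment variables. Below is a minimal sketch of how such an override is expected to behave; the import path is assumed, and the variables must be set before the package is imported because the constants are read at import time.

import os

# Hypothetical override values; the variable names come from the diff above.
os.environ["IMAGE_PLACEHOLDER"] = "<img>"
os.environ["VIDEO_PLACEHOLDER"] = "<vid>"

# Import afterwards so the module-level constants pick up the override.
from llamafactory.extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER

print(IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER)  # -> <img> <vid>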
@ -107,6 +108,7 @@ VISION_MODELS = set()
class DownloadSource(str, Enum): class DownloadSource(str, Enum):
DEFAULT = "hf" DEFAULT = "hf"
MODELSCOPE = "ms" MODELSCOPE = "ms"
OPENMIND = "om"
def register_model_group( def register_model_group(
@ -114,17 +116,12 @@ def register_model_group(
template: Optional[str] = None, template: Optional[str] = None,
vision: bool = False, vision: bool = False,
) -> None: ) -> None:
prefix = None
for name, path in models.items(): for name, path in models.items():
if prefix is None:
prefix = name.split("-")[0]
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path SUPPORTED_MODELS[name] = path
if template is not None: if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
DEFAULT_TEMPLATE[prefix] = template DEFAULT_TEMPLATE[name] = template
if vision: if vision:
VISION_MODELS.add(prefix) VISION_MODELS.add(name)
register_model_group( register_model_group(
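The refactored registration above keys DEFAULT_TEMPLATE by the full model name and only attaches a template to chat or instruct variants, so base checkpoints no longer inherit one by prefix. A tiny sketch of the resulting lookup behaviour, assuming the constants module path:

from llamafactory.extras.constants import DEFAULT_TEMPLATE  # import path assumed

DEFAULT_TEMPLATE.get("Llama-3-8B-Instruct")  # -> "llama3"
DEFAULT_TEMPLATE.get("Llama-3-8B")           # -> None, base models carry no default chat template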
@ -168,14 +165,17 @@ register_model_group(
"Baichuan2-13B-Base": { "Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
}, },
"Baichuan2-7B-Chat": { "Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
}, },
"Baichuan2-13B-Chat": { "Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
}, },
}, },
template="baichuan2", template="baichuan2",
@ -274,27 +274,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"ChineseLLaMA2-1.3B": { "Chinese-Llama-2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
}, },
"ChineseLLaMA2-7B": { "Chinese-Llama-2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
}, },
"ChineseLLaMA2-13B": { "Chinese-Llama-2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
}, },
"ChineseLLaMA2-1.3B-Chat": { "Chinese-Alpaca-2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
}, },
"ChineseLLaMA2-7B-Chat": { "Chinese-Alpaca-2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
}, },
"ChineseLLaMA2-13B-Chat": { "Chinese-Alpaca-2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
}, },
@ -450,25 +450,25 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"DeepSeekCoder-6.7B-Base": { "DeepSeek-Coder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
}, },
"DeepSeekCoder-7B-Base": { "DeepSeek-Coder-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
}, },
"DeepSeekCoder-33B-Base": { "DeepSeek-Coder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
}, },
"DeepSeekCoder-6.7B-Instruct": { "DeepSeek-Coder-6.7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
}, },
"DeepSeekCoder-7B-Instruct": { "DeepSeek-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
}, },
"DeepSeekCoder-33B-Instruct": { "DeepSeek-Coder-33B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
}, },
@ -477,6 +477,16 @@ register_model_group(
) )
register_model_group(
models={
"EXAONE-3.0-7.8B-Instruct": {
DownloadSource.DEFAULT: "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
},
},
template="exaone",
)
register_model_group( register_model_group(
models={ models={
"Falcon-7B": { "Falcon-7B": {
@ -550,10 +560,12 @@ register_model_group(
"Gemma-2-2B-Instruct": { "Gemma-2-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it", DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
}, },
"Gemma-2-9B-Instruct": { "Gemma-2-9B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it", DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
}, },
"Gemma-2-27B-Instruct": { "Gemma-2-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it", DownloadSource.DEFAULT: "google/gemma-2-27b-it",
@ -573,6 +585,7 @@ register_model_group(
"GLM-4-9B-Chat": { "GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat", DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat", DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
}, },
"GLM-4-9B-1M-Chat": { "GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m", DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
@ -583,6 +596,33 @@ register_model_group(
) )
register_model_group(
models={
"Index-1.9B-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
},
"Index-1.9B-Character-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
},
"Index-1.9B-Base": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
},
"Index-1.9B-Base-Pure": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
},
"Index-1.9B-Chat-32K": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
},
},
template="index",
)
register_model_group( register_model_group(
models={ models={
"InternLM-7B": { "InternLM-7B": {
@ -624,16 +664,10 @@ register_model_group(
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b", DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
}, },
},
template="intern2",
)
register_model_group(
models={
"InternLM2.5-1.8B": { "InternLM2.5-1.8B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b", DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
}, },
"InternLM2.5-7B": { "InternLM2.5-7B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b", DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
@ -642,22 +676,27 @@ register_model_group(
"InternLM2.5-20B": { "InternLM2.5-20B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b", DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
}, },
"InternLM2.5-1.8B-Chat": { "InternLM2.5-1.8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
}, },
"InternLM2.5-7B-Chat": { "InternLM2.5-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
}, },
"InternLM2.5-7B-1M-Chat": { "InternLM2.5-7B-1M-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
}, },
"InternLM2.5-20B-Chat": { "InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
}, },
}, },
template="intern2", template="intern2",
@ -686,19 +725,19 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA-7B": { "Llama-7B": {
DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
}, },
"LLaMA-13B": { "Llama-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b", DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b", DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
}, },
"LLaMA-30B": { "Llama-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b", DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b", DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
}, },
"LLaMA-65B": { "Llama-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b", DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b", DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
}, },
@ -708,27 +747,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA2-7B": { "Llama-2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
}, },
"LLaMA2-13B": { "Llama-2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
}, },
"LLaMA2-70B": { "Llama-2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
}, },
"LLaMA2-7B-Chat": { "Llama-2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
}, },
"LLaMA2-13B-Chat": { "Llama-2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
}, },
"LLaMA2-70B-Chat": { "Llama-2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
}, },
@ -739,60 +778,78 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA3-8B": { "Llama-3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
}, },
"LLaMA3-70B": { "Llama-3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
}, },
"LLaMA3-8B-Instruct": { "Llama-3-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
}, },
"LLaMA3-70B-Instruct": { "Llama-3-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
}, },
"LLaMA3-8B-Chinese-Chat": { "Llama-3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat", DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
}, },
"LLaMA3-70B-Chinese-Chat": { "Llama-3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
}, },
}, "Llama-3.1-8B": {
template="llama3",
)
register_model_group(
models={
"LLaMA3.1-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
}, },
"LLaMA3.1-70B": { "Llama-3.1-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
}, },
"LLaMA3.1-405B": { "Llama-3.1-405B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
}, },
"LLaMA3.1-8B-Instruct": { "Llama-3.1-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
}, },
"LLaMA3.1-70B-Instruct": { "Llama-3.1-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
}, },
"LLaMA3.1-405B-Instruct": { "Llama-3.1-405B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
}, },
"Llama-3.1-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-8B-Chinese-Chat",
},
"Llama-3.1-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-70B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-70B-Chinese-Chat",
},
"Llama-3.2-1B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B",
},
"Llama-3.2-3B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B",
},
"Llama-3.2-1B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B-Instruct",
},
"Llama-3.2-3B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
},
}, },
template="llama3", template="llama3",
) )
@ -800,11 +857,13 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaVA1.5-7B-Chat": { "LLaVA-1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-7b-hf",
}, },
"LLaVA1.5-13B-Chat": { "LLaVA-1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-13b-hf",
}, },
}, },
template="llava", template="llava",
@ -812,6 +871,117 @@ register_model_group(
) )
register_model_group(
models={
"LLaVA-NeXT-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-7b-hf",
},
"LLaVA-NeXT-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-13b-hf",
},
},
template="llava_next",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Mistral-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-mistral-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-mistral-7b-hf",
},
},
template="llava_next_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Llama3-8B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llama3-llava-next-8b-hf",
DownloadSource.MODELSCOPE: "swift/llama3-llava-next-8b-hf",
},
},
template="llava_next_llama3",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-34b-hf",
DownloadSource.MODELSCOPE: "LLM-Research/llava-v1.6-34b-hf",
},
},
template="llava_next_yi",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-72B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-72b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-72b-hf",
},
"LLaVA-NeXT-110B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-110b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-110b-hf",
},
},
template="llava_next_qwen",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-hf",
},
"LLaVA-NeXT-Video-7B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-DPO-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-DPO-hf",
},
},
template="llava_next_video",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-32k-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-32K-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-32K-hf",
},
},
template="llava_next_video_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-34B-hf",
},
"LLaVA-NeXT-Video-34B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-DPO-hf",
},
},
template="llava_next_video_yi",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"MiniCPM-2B-SFT-Chat": { "MiniCPM-2B-SFT-Chat": {
@ -832,6 +1002,7 @@ register_model_group(
"MiniCPM3-4B-Chat": { "MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B", DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
}, },
}, },
template="cpm3", template="cpm3",
@ -1005,27 +1176,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Phi3-4B-4k-Instruct": { "Phi-3-4B-4k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
}, },
"Phi3-4B-128k-Instruct": { "Phi-3-4B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
}, },
"Phi3-7B-8k-Instruct": { "Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
}, },
"Phi3-7B-128k-Instruct": { "Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
}, },
"Phi3-14B-8k-Instruct": { "Phi-3-14B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
}, },
"Phi3-14B-128k-Instruct": { "Phi-3-14B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
}, },
@ -1034,6 +1205,18 @@ register_model_group(
) )
register_model_group(
models={
"Pixtral-12B-Chat": {
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
}
},
template="pixtral",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"Qwen-1.8B": { "Qwen-1.8B": {
@ -1068,35 +1251,35 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
}, },
"Qwen-1.8B-int8-Chat": { "Qwen-1.8B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
}, },
"Qwen-1.8B-int4-Chat": { "Qwen-1.8B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
}, },
"Qwen-7B-int8-Chat": { "Qwen-7B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
}, },
"Qwen-7B-int4-Chat": { "Qwen-7B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
}, },
"Qwen-14B-int8-Chat": { "Qwen-14B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
}, },
"Qwen-14B-int4-Chat": { "Qwen-14B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
}, },
"Qwen-72B-int8-Chat": { "Qwen-72B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
}, },
"Qwen-72B-int4-Chat": { "Qwen-72B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
}, },
@ -1179,75 +1362,75 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
}, },
"Qwen1.5-0.5B-int8-Chat": { "Qwen1.5-0.5B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-0.5B-int4-Chat": { "Qwen1.5-0.5B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
}, },
"Qwen1.5-1.8B-int8-Chat": { "Qwen1.5-1.8B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-1.8B-int4-Chat": { "Qwen1.5-1.8B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
}, },
"Qwen1.5-4B-int8-Chat": { "Qwen1.5-4B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-4B-int4-Chat": { "Qwen1.5-4B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
}, },
"Qwen1.5-7B-int8-Chat": { "Qwen1.5-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-7B-int4-Chat": { "Qwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
}, },
"Qwen1.5-14B-int8-Chat": { "Qwen1.5-14B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-14B-int4-Chat": { "Qwen1.5-14B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
}, },
"Qwen1.5-32B-int4-Chat": { "Qwen1.5-32B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
}, },
"Qwen1.5-72B-int8-Chat": { "Qwen1.5-72B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-72B-int4-Chat": { "Qwen1.5-72B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
}, },
"Qwen1.5-110B-int4-Chat": { "Qwen1.5-110B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
}, },
"Qwen1.5-MoE-A2.7B-int4-Chat": { "Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
}, },
"Qwen1.5-Code-7B": { "CodeQwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B", DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B", DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
}, },
"Qwen1.5-Code-7B-Chat": { "CodeQwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat", DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat", DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
}, },
"Qwen1.5-Code-7B-int4-Chat": { "CodeQwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
}, },
@ -1281,14 +1464,17 @@ register_model_group(
"Qwen2-0.5B-Instruct": { "Qwen2-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
}, },
"Qwen2-1.5B-Instruct": { "Qwen2-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
}, },
"Qwen2-7B-Instruct": { "Qwen2-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
}, },
"Qwen2-72B-Instruct": { "Qwen2-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
@ -1568,51 +1754,53 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Qwen2VL-2B-Instruct": { "Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
}, },
"Qwen2VL-7B-Instruct": { "Qwen2-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
}, },
"Qwen2VL-72B-Instruct": { "Qwen2-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
}, },
"Qwen2VL-2B-Instruct-GPTQ-Int8": { "Qwen2-VL-2B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
}, },
"Qwen2VL-2B-Instruct-GPTQ-Int4": { "Qwen2-VL-2B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
}, },
"Qwen2VL-2B-Instruct-AWQ": { "Qwen2-VL-2B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
}, },
"Qwen2VL-7B-Instruct-GPTQ-Int8": { "Qwen2-VL-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
}, },
"Qwen2VL-7B-Instruct-GPTQ-Int4": { "Qwen2-VL-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
}, },
"Qwen2VL-7B-Instruct-AWQ": { "Qwen2-VL-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
}, },
"Qwen2VL-72B-Instruct-GPTQ-Int8": { "Qwen2-VL-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
}, },
"Qwen2VL-72B-Instruct-GPTQ-Int4": { "Qwen2-VL-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
}, },
"Qwen2VL-72B-Instruct-AWQ": { "Qwen2-VL-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
}, },
@ -1673,10 +1861,12 @@ register_model_group(
"TeleChat-7B-Chat": { "TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B", DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B", DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
}, },
"TeleChat-12B-Chat": { "TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B", DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B", DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
}, },
"TeleChat-12B-v2-Chat": { "TeleChat-12B-v2-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2", DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
@ -1689,11 +1879,11 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Vicuna1.5-7B-Chat": { "Vicuna-v1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5", DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
}, },
"Vicuna1.5-13B-Chat": { "Vicuna-v1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5", DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
}, },
@ -1702,6 +1892,17 @@ register_model_group(
) )
register_model_group(
models={
"Video-LLaVA-7B-Chat": {
DownloadSource.DEFAULT: "LanguageBind/Video-LLaVA-7B-hf",
},
},
template="video_llava",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"XuanYuan-6B": { "XuanYuan-6B": {
@ -1712,7 +1913,7 @@ register_model_group(
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
}, },
"XuanYuan-2-70B": { "XuanYuan2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
}, },
@ -1724,31 +1925,31 @@ register_model_group(
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
}, },
"XuanYuan-2-70B-Chat": { "XuanYuan2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
}, },
"XuanYuan-6B-int8-Chat": { "XuanYuan-6B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
}, },
"XuanYuan-6B-int4-Chat": { "XuanYuan-6B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
}, },
"XuanYuan-70B-int8-Chat": { "XuanYuan-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
}, },
"XuanYuan-70B-int4-Chat": { "XuanYuan-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
}, },
"XuanYuan-2-70B-int8-Chat": { "XuanYuan2-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
}, },
"XuanYuan-2-70B-int4-Chat": { "XuanYuan2-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
}, },
@ -1853,19 +2054,19 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
}, },
"Yi-6B-int8-Chat": { "Yi-6B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
}, },
"Yi-6B-int4-Chat": { "Yi-6B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits", DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
}, },
"Yi-34B-int8-Chat": { "Yi-34B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
}, },
"Yi-34B-int4-Chat": { "Yi-34B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits", DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
}, },
@ -1884,6 +2085,7 @@ register_model_group(
"Yi-1.5-6B-Chat": { "Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
}, },
"Yi-1.5-9B-Chat": { "Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
@ -1916,10 +2118,10 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"YiVL-6B-Chat": { "Yi-VL-6B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf", DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
}, },
"YiVL-34B-Chat": { "Yi-VL-34B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf", DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
}, },
}, },

View File

@ -72,4 +72,4 @@ def print_env() -> None:
except Exception: except Exception:
pass pass
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n") print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")

View File

@ -20,6 +20,7 @@ import os
import sys import sys
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Optional from typing import Optional
from .constants import RUNNING_LOG from .constants import RUNNING_LOG
@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
def __init__(self, output_dir: str) -> None: def __init__(self, output_dir: str) -> None:
super().__init__() super().__init__()
formatter = logging.Formatter( self._formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
) )
self.setLevel(logging.INFO) self.setLevel(logging.INFO)
self.setFormatter(formatter)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
self.running_log = os.path.join(output_dir, RUNNING_LOG) self.running_log = os.path.join(output_dir, RUNNING_LOG)
if os.path.exists(self.running_log): if os.path.exists(self.running_log):
@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
if record.name == "httpx": if record.name == "httpx":
return return
log_entry = self.format(record) log_entry = self._formatter.format(record)
self.thread_pool.submit(self._write_log, log_entry) self.thread_pool.submit(self._write_log, log_entry)
def close(self) -> None: def close(self) -> None:
@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
return super().close() return super().close()
class _Logger(logging.Logger):
r"""
A logger that supports info_rank0 and warning_once.
"""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def _get_default_logging_level() -> "logging._Level": def _get_default_logging_level() -> "logging._Level":
r""" r"""
Returns the default logging level. Returns the default logging level.
@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
if env_level_str.upper() in logging._nameToLevel: if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()] return logging._nameToLevel[env_level_str.upper()]
else: else:
raise ValueError("Unknown logging level: {}.".format(env_level_str)) raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level return _default_log_level
@ -84,7 +99,7 @@ def _get_library_name() -> str:
return __name__.split(".")[0] return __name__.split(".")[0]
def _get_library_root_logger() -> "logging.Logger": def _get_library_root_logger() -> "_Logger":
return logging.getLogger(_get_library_name()) return logging.getLogger(_get_library_name())
@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
global _default_handler global _default_handler
with _thread_lock: with _thread_lock:
if _default_handler: if _default_handler: # already configured
return return
formatter = logging.Formatter( formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
datefmt="%m/%d/%Y %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
) )
_default_handler = logging.StreamHandler(sys.stdout) _default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter) _default_handler.setFormatter(formatter)
@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "logging.Logger": def get_logger(name: Optional[str] = None) -> "_Logger":
r""" r"""
Returns a logger with the specified name. It is not supposed to be accessed externally. Returns a logger with the specified name. It is not supposed to be accessed externally.
""" """
@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
_configure_library_root_logger() _configure_library_root_logger()
return logging.getLogger(name) return logging.getLogger(name)
def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""
Removes a handler from the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.info(*args, **kwargs)
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once
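The three helpers above are attached to logging.Logger so that, in distributed runs, only the process with LOCAL_RANK 0 emits messages and repeated warnings collapse to a single line. A short usage sketch, assuming the module path llamafactory.extras.logging and a torchrun launch that sets LOCAL_RANK:

from llamafactory.extras import logging  # import path assumed from the repository layout

logger = logging.get_logger(__name__)

# Only local rank 0 writes these lines; other workers stay silent.
logger.info_rank0("Loading tokenizer...")
logger.warning_rank0("pad_token is missing, falling back to eos_token.")

# warning_once is wrapped in lru_cache, so identical calls log once per process.
for _ in range(3):
    logger.warning_once("Flash attention is not installed.")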

View File

@ -32,7 +32,7 @@ from transformers.utils import (
) )
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from .logging import get_logger from . import logging
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@ -48,7 +48,7 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments from ..hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class AverageMeter: class AverageMeter:
@ -76,12 +76,12 @@ def check_dependencies() -> None:
r""" r"""
Checks the version of the required packages. Checks the version of the required packages.
""" """
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
else: else:
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0") require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0") require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2") require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0") require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6") require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@ -231,18 +231,35 @@ def torch_gc() -> None:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def try_download_model_from_ms(model_args: "ModelArguments") -> str: def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
if not use_modelscope() or os.path.exists(model_args.model_name_or_path): if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
return model_args.model_name_or_path return model_args.model_name_or_path
try: if use_modelscope():
from modelscope import snapshot_download require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision revision = "master" if model_args.model_revision == "main" else model_args.model_revision
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir) return snapshot_download(
except ImportError: model_args.model_name_or_path,
raise ImportError("Please install modelscope via `pip install modelscope -U`") revision=revision,
cache_dir=model_args.cache_dir,
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
model_args.model_name_or_path,
revision=model_args.model_revision,
cache_dir=model_args.cache_dir,
)
def use_modelscope() -> bool: def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]

View File

@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
return _get_package_version("transformers") >= version.parse("4.43.0") return _get_package_version("transformers") >= version.parse("4.43.0")
@lru_cache
def is_transformers_version_equal_to_4_46():
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
def is_uvicorn_available(): def is_uvicorn_available():
return _is_package_available("uvicorn") return _is_package_available("uvicorn")

View File

@ -19,7 +19,7 @@ from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
from .logging import get_logger from . import logging
from .packages import is_matplotlib_available from .packages import is_matplotlib_available
@ -28,7 +28,7 @@ if is_matplotlib_available():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]: def smooth(scalars: List[float]) -> List[float]:
@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
Plots loss curves and saves the image. Plots loss curves and saves the image.
""" """
plt.switch_backend("agg") plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
for key in keys: for key in keys:
@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
metrics.append(data["log_history"][i][key]) metrics.append(data["log_history"][i][key])
if len(metrics) == 0: if len(metrics) == 0:
logger.warning(f"No metric {key} to plot.") logger.warning_rank0(f"No metric {key} to plot.")
continue continue
plt.figure() plt.figure()
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original") plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed") plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
plt.title("training {} of {}".format(key, save_dictionary)) plt.title(f"training {key} of {save_dictionary}")
plt.xlabel("step") plt.xlabel("step")
plt.ylabel(key) plt.ylabel(key)
plt.legend() plt.legend()

View File

@ -41,6 +41,10 @@ class DataArguments:
default="data", default="data",
metadata={"help": "Path to the folder containing the datasets."}, metadata={"help": "Path to the folder containing the datasets."},
) )
image_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
)
cutoff_len: int = field( cutoff_len: int = field(
default=1024, default=1024,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
@ -111,7 +115,13 @@ class DataArguments:
) )
tokenized_path: Optional[str] = field( tokenized_path: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to save or load the tokenized datasets."}, metadata={
"help": (
"Path to save or load the tokenized datasets. "
"If tokenized_path not exists, it will save the tokenized datasets. "
"If tokenized_path exists, it will load the tokenized datasets."
)
},
) )
def __post_init__(self): def __post_init__(self):
@ -123,6 +133,9 @@ class DataArguments:
self.dataset = split_arg(self.dataset) self.dataset = split_arg(self.dataset)
self.eval_dataset = split_arg(self.eval_dataset) self.eval_dataset = split_arg(self.eval_dataset)
if self.image_dir is None:
self.image_dir = self.dataset_dir
if self.dataset is None and self.val_size > 1e-6: if self.dataset is None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `dataset` is None.") raise ValueError("Cannot specify `val_size` if `dataset` is None.")

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import torch import torch
@ -267,6 +267,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=None, default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."}, metadata={"help": "Auth token to log in with ModelScope Hub."},
) )
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field( print_param_status: bool = field(
default=False, default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
@ -308,20 +312,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
if self.export_quantization_bit is not None and self.export_quantization_dataset is None: if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.") raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@classmethod @classmethod
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self": def copyfrom(cls, source: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict() init_args, lazy_args = {}, {}
arg_dict.update(**kwargs) for attr in fields(source):
for attr in fields(cls): if attr.init:
if not attr.init: init_args[attr.name] = getattr(source, attr.name)
arg_dict.pop(attr.name) else:
lazy_args[attr.name] = getattr(source, attr.name)
new_arg = cls(**arg_dict) init_args.update(kwargs)
new_arg.compute_dtype = old_arg.compute_dtype result = cls(**init_args)
new_arg.device_map = old_arg.device_map for name, value in lazy_args.items():
new_arg.model_max_length = old_arg.model_max_length setattr(result, name, value)
new_arg.block_diag_attn = old_arg.block_diag_attn
return new_arg return result
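The rewritten `copyfrom` relies on `dataclasses.fields()` to split `init=True` fields (passed to the constructor) from `init=False` fields (restored afterwards with `setattr`), so derived attributes survive the copy without being listed by hand. A standalone sketch of the same pattern with a toy dataclass:

```python
# Toy reproduction of the copy pattern; the Args class here is illustrative.
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class Args:
    model_name_or_path: str
    compute_dtype: Optional[str] = field(default=None, init=False)  # set lazily later

    @classmethod
    def copyfrom(cls, source: "Args", **kwargs) -> "Args":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            (init_args if attr.init else lazy_args)[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)           # overrides only apply to init fields
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)   # carry over non-init (lazy) fields

        return result


old = Args("meta-llama/Meta-Llama-3-8B")
old.compute_dtype = "bfloat16"
new = Args.copyfrom(old, model_name_or_path="another-model")
assert new.compute_dtype == "bfloat16"
```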

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os import os
import sys import sys
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES from ..extras.constants import CHECKPOINT_NAMES
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments from .data_args import DataArguments
from .evaluation_args import EvaluationArguments from .evaluation_args import EvaluationArguments
@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
from .model_args import ModelArguments from .model_args import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
check_dependencies() check_dependencies()
@ -57,7 +56,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
if args is not None: if args is not None:
return parser.parse_dict(args) return parser.parse_dict(args)
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
if unknown_args: if unknown_args:
print(parser.format_help()) print(parser.format_help())
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,) return (*parsed_args,)
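`_parse_args` now dispatches on the file suffix of a single CLI argument, accepting `.yml` alongside `.yaml` and `.json`. A hedged sketch of the same dispatch with the `HfArgumentParser` API and an illustrative dataclass:

```python
# Sketch only; TrainConfig is a stand-in for the real argument dataclasses.
import os
import sys
from dataclasses import dataclass

from transformers import HfArgumentParser


@dataclass
class TrainConfig:
    learning_rate: float = 5e-5


parser = HfArgumentParser(TrainConfig)
if len(sys.argv) == 2 and sys.argv[1].endswith((".yaml", ".yml")):
    (config,) = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    (config,) = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
    (config,) = parser.parse_args_into_dataclasses()
```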
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: def _set_transformers_logging() -> None:
transformers.utils.logging.set_verbosity(log_level) transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format() transformers.utils.logging.enable_explicit_format()
@ -104,7 +103,7 @@ def _verify_model_args(
raise ValueError("Quantized model only accepts a single adapter. Merge them first.") raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer: if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False model_args.use_fast_tokenizer = False
@ -123,7 +122,7 @@ def _check_extra_dependencies(
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0") require_version("vllm>=0.4.3,<=0.6.3", "To fix: pip install vllm>=0.4.3,<=0.6.3")
if finetuning_args.use_galore: if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch") require_version("galore_torch", "To fix: pip install galore_torch")
@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing: if data_args.neat_packing and not data_args.packing:
logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.") logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
data_args.packing = True data_args.packing = True
_verify_model_args(model_args, data_args, finetuning_args) _verify_model_args(model_args, data_args, finetuning_args)
@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and model_args.resize_vocab and model_args.resize_vocab
and finetuning_args.additional_target is None and finetuning_args.additional_target is None
): ):
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.") logger.warning_rank0(
"Remember to add embedding layers to `additional_target` to make the added tokens trainable."
)
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.") logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.") logger.warning_rank0("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.") logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None: if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.") logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments # Post-process training arguments
if ( if (
@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and training_args.ddp_find_unused_parameters is None and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora" and finetuning_args.finetuning_type == "lora"
): ):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None: if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.") logger.warning_rank0("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None training_args.resume_from_checkpoint = None
else: else:
can_resume_from_checkpoint = True can_resume_from_checkpoint = True
@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if last_checkpoint is not None: if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint training_args.resume_from_checkpoint = last_checkpoint
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint)) logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.") logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if ( if (
finetuning_args.stage in ["rm", "ppo"] finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora" and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint is not None
): ):
logger.warning( logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint training_args.resume_from_checkpoint
) )

View File

@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@ -33,7 +33,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _setup_full_tuning( def _setup_full_tuning(
@ -45,7 +45,7 @@ def _setup_full_tuning(
if not is_trainable: if not is_trainable:
return return
logger.info("Fine-tuning method: Full") logger.info_rank0("Fine-tuning method: Full")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args) forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules): if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@ -64,7 +64,7 @@ def _setup_freeze_tuning(
if not is_trainable: if not is_trainable:
return return
logger.info("Fine-tuning method: Freeze") logger.info_rank0("Fine-tuning method: Freeze")
if hasattr(model.config, "text_config"): # composite models if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config") config = getattr(model.config, "text_config")
else: else:
@ -133,7 +133,7 @@ def _setup_freeze_tuning(
else: else:
param.requires_grad_(False) param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(trainable_layers))) logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning( def _setup_lora_tuning(
@ -145,7 +145,7 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
) -> "PeftModel": ) -> "PeftModel":
if is_trainable: if is_trainable:
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None adapter_to_resume = None
@ -182,7 +182,7 @@ def _setup_lora_tuning(
model = model.merge_and_unload() model = model.merge_and_unload()
if len(adapter_to_merge) > 0: if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth: if model_args.use_unsloth:
@ -190,7 +190,7 @@ def _setup_lora_tuning(
else: else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@ -219,7 +219,7 @@ def _setup_lora_tuning(
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = { peft_kwargs = {
"r": finetuning_args.lora_rank, "r": finetuning_args.lora_rank,
@ -236,11 +236,11 @@ def _setup_lora_tuning(
else: else:
if finetuning_args.pissa_init: if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1: if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.") logger.info_rank0("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa" peft_kwargs["init_lora_weights"] = "pissa"
else: else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig( lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
@ -284,11 +284,11 @@ def init_adapter(
if not is_trainable: if not is_trainable:
pass pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam: elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.") logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.") logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else: else:
logger.info("Upcasting trainable params to float32.") logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full": if finetuning_args.finetuning_type == "full":
@ -300,6 +300,6 @@ def init_adapter(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
) )
else: else:
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type)) raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
return model return model
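For reference, the PiSSA branch above maps onto PEFT's `init_lora_weights` option, which accepts `"pissa"` or `"pissa_niter_{n}"` for the fast-SVD variant. A minimal sketch with assumed hyperparameter values (peft>=0.11):

```python
# Values are illustrative; only init_lora_weights mirrors the logic in the diff.
from peft import LoraConfig, TaskType

pissa_iter = 16
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=["q_proj", "v_proj"],
    init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
)
```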

View File

@ -18,15 +18,15 @@ import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING: if TYPE_CHECKING:
@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class TokenizerModule(TypedDict): class TokenizerModule(TypedDict):
@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
skip_check_imports() skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args) model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
return { return {
"trust_remote_code": True, "trust_remote_code": True,
"cache_dir": model_args.cache_dir, "cache_dir": model_args.cache_dir,
@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r""" r"""
Loads pretrained tokenizer. Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
@ -82,33 +82,30 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
padding_side="right", padding_side="right",
**init_kwargs, **init_kwargs,
) )
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
if model_args.new_special_tokens is not None: if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens( num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens), dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False, replace_additional_special_tokens=False,
) )
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab: if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.") logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer) patch_tokenizer(tokenizer)
try: try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer) patch_processor(processor, config, tokenizer, model_args)
setattr(processor, "image_seqlen", get_image_seqlen(config)) except Exception as e:
setattr(processor, "image_resolution", model_args.image_resolution) logger.debug(f"Processor was not found: {e}.")
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
except Exception:
processor = None processor = None
# Avoid load tokenizer, see: # Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if "Processor" not in processor.__class__.__name__: if processor is not None and "Processor" not in processor.__class__.__name__:
processor = None processor = None
return {"tokenizer": tokenizer, "processor": processor} return {"tokenizer": tokenizer, "processor": processor}
@ -135,6 +132,7 @@ def load_model(
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args) config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None model = None
lazy_load = False lazy_load = False
@ -157,7 +155,7 @@ def load_model(
load_class = AutoModelForCausalLM load_class = AutoModelForCausalLM
if model_args.train_from_scratch: if model_args.train_from_scratch:
model = load_class.from_config(config) model = load_class.from_config(config, trust_remote_code=True)
else: else:
model = load_class.from_pretrained(**init_kwargs) model = load_class.from_pretrained(**init_kwargs)
@ -182,7 +180,7 @@ def load_model(
vhead_params = load_valuehead_params(vhead_path, model_args) vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None: if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False) model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
if not is_trainable: if not is_trainable:
model.requires_grad_(False) model.requires_grad_(False)
@ -200,9 +198,9 @@ def load_model(
trainable_params, all_param, 100 * trainable_params / all_param trainable_params, all_param, 100 * trainable_params / all_param
) )
else: else:
param_stats = "all params: {:,}".format(all_param) param_stats = f"all params: {all_param:,}"
logger.info(param_stats) logger.info_rank0(param_stats)
if model_args.print_param_status: if model_args.print_param_status:
for name, param in model.named_parameters(): for name, param in model.named_parameters():

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def configure_attn_implementation( def configure_attn_implementation(
@ -37,13 +37,16 @@ def configure_attn_implementation(
if is_flash_attn_2_available(): if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") if model_args.flash_attn != "fa2":
model_args.flash_attn = "fa2" logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else: else:
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.") logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled" model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.") logger.warning_rank0(
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
)
if model_args.flash_attn == "auto": if model_args.flash_attn == "auto":
return return
@ -53,18 +56,18 @@ def configure_attn_implementation(
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
if not is_torch_sdpa_available(): if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.") logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return return
requested_attn_implementation = "sdpa" requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2": elif model_args.flash_attn == "fa2":
if not is_flash_attn_2_available(): if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.") logger.warning_rank0("FlashAttention-2 is not installed.")
return return
requested_attn_implementation = "flash_attention_2" requested_attn_implementation = "flash_attention_2"
else: else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation) setattr(config, "attn_implementation", requested_attn_implementation)
@ -79,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
attn_implementation = getattr(config, "_attn_implementation", None) attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2": if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.") logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa": elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.") logger.info_rank0("Using torch SDPA for faster training and inference.")
else: else:
logger.info("Using vanilla attention implementation.") logger.info_rank0("Using vanilla attention implementation.")

View File

@ -19,14 +19,14 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from functools import partial, wraps from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable: def get_unsloth_gradient_checkpointing_func() -> Callable:
@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
Only applies gradient checkpointing to trainable layers. Only applies gradient checkpointing to trainable layers.
""" """
@wraps(gradient_checkpointing_func) @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs): def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__ module: "torch.nn.Module" = func.__self__
@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
return gradient_checkpointing_func(func, *args, **kwargs) return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func return custom_gradient_checkpointing_func
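The change above drops the manual `__self__` fix-up: passing `assigned=WRAPPER_ASSIGNMENTS + ("__self__",)` makes `functools.wraps` copy the bound-method attribute onto the wrapper whenever it exists. A self-contained sketch of the trick with a toy class:

```python
# Toy demonstration of carrying __self__ through functools.wraps.
from functools import WRAPPER_ASSIGNMENTS, wraps


class Module:
    def checkpoint(self, x):
        return x * 2


bound = Module().checkpoint


@wraps(bound, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def wrapped(x):
    return bound(x)


assert wrapped.__self__ is bound.__self__  # copied by wraps, no manual assignment
assert wrapped(3) == 6
```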
@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing: if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None: if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True} gradient_checkpointing_kwargs = {"use_reentrant": True}
@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True)) self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads() self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
(3) add the upcasting of the lm_head in fp32 (3) add the upcasting of the lm_head in fp32
""" """
if model_args.upcast_layernorm: if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.") logger.info_rank0("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing: if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False): if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.") logger.warning_rank0("Current model does not support gradient checkpointing.")
else: else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet) # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339 # According to: https://github.com/huggingface/transformers/issues/28339
@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model) model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.") logger.info_rank0("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output: if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings() output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.") logger.info_rank0("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook) output_layer.register_forward_hook(_fp32_forward_post_hook)

View File

@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -23,10 +24,15 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
is_trainable: bool,
require_logits: bool,
) -> None:
if not is_trainable or not model_args.enable_liger_kernel: if not is_trainable or not model_args.enable_liger_kernel:
return return
@ -48,8 +54,14 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
else: else:
logger.warning("Current model does not support liger kernel.") logger.warning_rank0("Current model does not support liger kernel.")
return return
apply_liger_kernel() if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Liger kernel has been applied to the model.") logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info_rank0("Liger kernel has been applied to the model.")

View File

@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
Cache, Cache,
LlamaAttention, LlamaAttention,
@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb, apply_rotary_pos_emb,
repeat_kv, repeat_kv,
) )
from transformers.utils import logging
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43 from ...extras.packages import is_transformers_version_greater_than_4_43
@ -44,7 +44,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
transformers_logger = logging.get_logger(__name__) transformers_logger = transformers.utils.logging.get_logger(__name__)
# Modified from: # Modified from:
@ -86,7 +86,7 @@ def llama_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None: def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0") require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
if not is_trainable or not model_args.shift_attn: if not is_trainable or not model_args.shift_attn:
return return
logger = get_logger(__name__) logger = logging.get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25) setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch() _apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.") logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
else: else:
logger.warning("Current model does not support shift short attention.") logger.warning_rank0("Current model does not support shift short attention.")

View File

@ -14,14 +14,14 @@
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@ -34,7 +34,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output_layer") forbidden_modules.add("output_layer")
elif model_type == "internlm2": elif model_type == "internlm2":
forbidden_modules.add("output") forbidden_modules.add("output")
elif model_type in ["llava", "paligemma"]: elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
forbidden_modules.add("multi_modal_projector") forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
forbidden_modules.add("merger") forbidden_modules.add("merger")
@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names))) logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
return list(module_names) return list(module_names)
@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
if num_layers % num_layer_trainable != 0: if num_layers % num_layer_trainable != 0:
raise ValueError( raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable) f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
) )
stride = num_layers // num_layer_trainable stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids] trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
module_names = [] module_names = []
for name, _ in model.named_modules(): for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any( if any(target_module in name for target_module in target_modules) and any(
@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
): ):
module_names.append(name) module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names return module_names
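The stride arithmetic above spaces the trainable layers evenly across the stack. A worked example with assumed values:

```python
# 32 layers with 8 trainable layers -> stride 4, LoRA on every 4th layer id.
num_layers, num_layer_trainable = 32, 8
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
print([f".{idx:d}." for idx in trainable_layer_ids])
# ['.3.', '.7.', '.11.', '.15.', '.19.', '.23.', '.27.', '.31.']
```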

View File

@ -43,8 +43,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43 from ...extras.packages import is_transformers_version_greater_than_4_43
@ -54,7 +54,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None: def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0") require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
if is_transformers_version_greater_than_4_43(): if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils import transformers.modeling_flash_attention_utils
@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
_patch_for_block_diag_attn(model_type) _patch_for_block_diag_attn(model_type)
logger.info("Using block diagonal attention for sequence packing without cross-attention.") logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
else: else:
raise ValueError("Current model does not support block diagonal attention.") raise ValueError("Current model does not support block diagonal attention.")

View File

@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import FILEEXT2TYPE from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device from ...extras.misc import get_current_device
@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
@unique @unique
@ -109,7 +109,7 @@ def configure_quantization(
""" """
if getattr(config, "quantization_config", None): # ptq if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.") logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.") raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
@ -130,7 +130,7 @@ def configure_quantization(
quantization_config["bits"] = 2 quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?") quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]: if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@ -149,7 +149,7 @@ def configure_quantization(
) )
init_kwargs["device_map"] = "auto" init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory() init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
elif model_args.quantization_bit is not None: # on-the-fly elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@ -179,7 +179,7 @@ def configure_quantization(
else: else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
elif model_args.quantization_method == QuantizationMethod.HQQ.value: elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]: if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.") raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@ -191,7 +191,7 @@ def configure_quantization(
init_kwargs["quantization_config"] = HqqConfig( init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance ) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
elif model_args.quantization_method == QuantizationMethod.EETQ.value: elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8: if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.") raise ValueError("EETQ only accepts 8-bit quantization.")
@ -201,4 +201,4 @@ def configure_quantization(
require_version("eetq", "To fix: pip install eetq") require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig() init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
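The quantization hunk above routes `quantization_bit` to different backends (GPTQ, bitsandbytes, HQQ, EETQ). Below is a hedged sketch of the on-the-fly bitsandbytes path, building the `quantization_config` that would be forwarded to `from_pretrained`; the compute dtype and device map here are illustrative assumptions, not the project's defaults.

```python
# Sketch: building bitsandbytes quantization kwargs for on-the-fly loading.
import torch
from transformers import BitsAndBytesConfig


def build_bnb_init_kwargs(quantization_bit: int) -> dict:
    if quantization_bit == 8:
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
    elif quantization_bit == 4:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")

    # These kwargs would be passed to AutoModelForCausalLM.from_pretrained(...).
    return {"quantization_config": quant_config, "device_map": {"": 0}}


if __name__ == "__main__":
    print(build_bnb_init_kwargs(4))
```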

View File

@ -19,7 +19,7 @@
import math import math
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
return return
if not hasattr(config, "rope_scaling"): if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.") logger.warning_rank0("Current model does not support RoPE scaling.")
return return
if model_args.model_max_length is not None: if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic": if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning( logger.warning_rank0(
"Dynamic NTK scaling may not work well with fine-tuning. " "Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653" "See: https://github.com/huggingface/transformers/pull/24653"
) )
current_max_length = getattr(config, "max_position_embeddings", None) current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length: if current_max_length and model_args.model_max_length > current_max_length:
logger.info( logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
setattr(config, "max_position_embeddings", model_args.model_max_length) setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else: else:
logger.warning("Input length is smaller than max length. Consider increasing the input length.") logger.warning_rank0("Input length is smaller than max length. Consider increasing the input length.")
scaling_factor = 1.0 scaling_factor = 1.0
else: else:
scaling_factor = 2.0 scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info( logger.info_rank0(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
) )
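The RoPE hunk computes the scaling factor as the ceiling of the requested `model_max_length` over the model's native `max_position_embeddings`. A short worked example of that arithmetic follows; only the main branch is reproduced, and the fallback factor of 2.0 used when no `model_max_length` is given is omitted.

```python
# Sketch: RoPE scaling factor from requested vs. native context length.
import math
from typing import Optional, Tuple


def rope_scaling(model_max_length: Optional[int], native_max_length: int) -> Tuple[int, float]:
    if model_max_length is not None and model_max_length > native_max_length:
        factor = float(math.ceil(model_max_length / native_max_length))
        return model_max_length, factor  # also enlarge max_position_embeddings
    return native_max_length, 1.0  # input shorter than native context: no scaling


if __name__ == "__main__":
    print(rope_scaling(16384, 4096))  # (16384, 4.0): stretch a 4k-context model to 16k
    print(rope_scaling(2048, 4096))   # (4096, 1.0)
```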

View File

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger from ...extras import logging
from ...extras.misc import get_current_device from ...extras.misc import get_current_device
@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _get_unsloth_kwargs( def _get_unsloth_kwargs(
@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
try: try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError: except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None model = None
model_args.use_unsloth = False model_args.use_unsloth = False

View File

@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
import torch import torch
from transformers.utils import cached_file from transformers.utils import cached_file
from ...extras import logging
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
@ -27,7 +27,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
except Exception as err: except Exception as err:
err_text = str(err) err_text = str(err)
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text)) logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info("Ignore the above message if you are not resuming the training of a value head model.") logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
return None return None
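The value-head loader above tries the safetensors weights first and falls back to the legacy `.bin` file, returning `None` when neither is found. Below is a simplified sketch of that fallback chain; the filenames are assumptions matching the constant names in the imports, and error handling is reduced to a bare `except`.

```python
# Sketch: loading value-head weights with a safetensors-then-torch fallback.
from typing import Dict, Optional

import torch
from safetensors.torch import load_file
from transformers.utils import cached_file

V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"  # assumed filename
V_HEAD_WEIGHTS_NAME = "value_head.bin"  # assumed filename


def load_valuehead_params(path_or_repo_id: str) -> Optional[Dict[str, torch.Tensor]]:
    try:
        vhead_file = cached_file(path_or_repo_id, V_HEAD_SAFE_WEIGHTS_NAME)
        return load_file(vhead_file, device="cpu")
    except Exception:
        pass

    try:
        vhead_file = cached_file(path_or_repo_id, V_HEAD_WEIGHTS_NAME)
        return torch.load(vhead_file, map_location="cpu")
    except Exception:
        return None  # not a value-head checkpoint; the caller logs and continues
```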

View File

@ -18,11 +18,11 @@
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch import torch
import transformers
import transformers.models import transformers.models
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.utils import logging
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -31,8 +31,8 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments, ModelArguments from ...hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
transformers_logger = logging.get_logger(__name__) transformers_logger = transformers.utils.logging.get_logger(__name__)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module): class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if getattr(model, "quantization_method", None): if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None) model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "paligemma"]: if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector") mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger") mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
else: else:
return return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
mm_projector.register_forward_hook(_mm_projector_forward_post_hook) mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
@ -108,11 +108,18 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
Patches VLMs before loading them. Patches VLMs before loading them.
""" """
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
if model_type == "llava": # required for ds zero3 and valuehead models if model_type in [
"llava",
"llava_next",
"llava_next_video",
"paligemma",
"pixtral",
"video_llava",
]: # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.") logger.info_rank0("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
@ -122,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
""" """
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
forbidden_modules = set() forbidden_modules = set()
if model_type in ["llava", "paligemma"]: if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
if finetuning_args.freeze_vision_tower: if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower") forbidden_modules.add("vision_tower")
@ -150,12 +157,28 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
image_seqlen += 1 image_seqlen += 1
elif model_type == "paligemma": elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens image_seqlen = config.vision_config.num_image_tokens
elif model_type == "qwen2_vl": # variable length else:
image_seqlen = -1 image_seqlen = -1
return image_seqlen return image_seqlen
def get_patch_size(config: "PretrainedConfig") -> int:
r"""
Computes the patch size of the vit.
"""
patch_size = getattr(config.vision_config, "patch_size", -1)
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> str:
r"""
Get the vision_feature_select_strategy.
"""
vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
return vision_feature_select_strategy
def patch_target_modules( def patch_target_modules(
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str] config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
@ -164,7 +187,7 @@ def patch_target_modules(
""" """
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
if finetuning_args.freeze_vision_tower: if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "paligemma"]: if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
@ -173,5 +196,7 @@ def patch_target_modules(
else: else:
if model_type == "qwen2_vl": if model_type == "qwen2_vl":
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules)) return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
elif model_type == "pixtral":
return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
else: else:
return target_modules return target_modules
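A quick check of the negative-lookahead pattern built by `patch_target_modules` when the vision tower is frozen: LoRA targets are restricted to modules whose full names do not contain `vision_tower`. The module names below are hypothetical examples.

```python
# Sketch: testing the frozen-vision-tower target-module regex from the hunk above.
import re

target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))

names = [
    "language_model.model.layers.0.self_attn.q_proj",
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj",
]
for name in names:
    print(name, "->", bool(re.match(pattern, name)))
# the language-model projection matches, the vision-tower projection does not
```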

View File

@ -22,29 +22,34 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger from ..extras import logging
from ..extras.misc import infer_optim_dtype from ..extras.misc import infer_optim_dtype
from .model_utils.attention import configure_attn_implementation, print_attn_implementation from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer from .model_utils.embedding import resize_embedding_layer
from .model_utils.liger_kernel import configure_liger_kernel
from .model_utils.longlora import configure_longlora from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import autocast_projector_dtype, configure_visual_model from .model_utils.visual import (
autocast_projector_dtype,
configure_visual_model,
get_image_seqlen,
get_patch_size,
get_vision_feature_select_strategy,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments from ..hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@ -52,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def patch_processor(
processor: "ProcessorMixin",
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
) -> None:
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "patch_size", get_patch_size(config))
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
def patch_config( def patch_config(
config: "PretrainedConfig", config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
@ -71,7 +92,6 @@ def patch_config(
configure_attn_implementation(config, model_args, is_trainable) configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable) configure_rope(config, model_args, is_trainable)
configure_liger_kernel(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable) configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs) configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable) configure_moe(config, model_args, is_trainable)
@ -80,7 +100,7 @@ def patch_config(
if model_args.use_cache and not is_trainable: if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True) setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.") logger.info_rank0("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen": if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
@ -90,6 +110,9 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2": if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
# deepspeed zero3 is not compatible with low_cpu_mem_usage # deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled()) init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@ -142,7 +165,7 @@ def patch_model(
try: try:
model.add_model_tags(["llama-factory"]) model.add_model_tags(["llama-factory"])
except Exception: except Exception:
logger.warning("Cannot properly tag the model.") logger.warning_rank0("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
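`patch_processor` stores `image_seqlen` (along with the patch size and the vision feature selection strategy) on the processor so the data pipeline can expand image placeholders to the right number of tokens. A back-of-the-envelope sketch of that count for a llava-style ViT encoder follows; the helper name is illustrative, not the project's API.

```python
# Sketch: number of image tokens for a ViT encoder with square patches.
def image_seqlen(image_size: int, patch_size: int, strategy: str = "default") -> int:
    seqlen = (image_size // patch_size) ** 2
    if strategy == "full":
        seqlen += 1  # keep the CLS token
    return seqlen


if __name__ == "__main__":
    # CLIP ViT-L/14 at 336 px, as used by llava-1.5: 24 * 24 = 576 image tokens
    print(image_seqlen(336, 14))          # 576
    print(image_seqlen(336, 14, "full"))  # 577
```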

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import json import json
import logging
import os import os
import signal import signal
import sys import sys
@ -34,8 +33,8 @@ from transformers.utils import (
) )
from typing_extensions import override from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger
from ..extras.misc import get_peak_memory from ..extras.misc import get_peak_memory
@ -48,7 +47,7 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def fix_valuehead_checkpoint( def fix_valuehead_checkpoint(
@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
else: else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
logger.info("Value head model saved at: {}".format(output_dir)) logger.info_rank0(f"Value head model saved at: {output_dir}")
class FixValueHeadModelCallback(TrainerCallback): class FixValueHeadModelCallback(TrainerCallback):
@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback):
Event called after a checkpoint save. Event called after a checkpoint save.
""" """
if args.should_save: if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
fix_valuehead_checkpoint( fix_valuehead_checkpoint(
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
) )
@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback):
@override @override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save: if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
getattr(self.processor, "image_processor").save_pretrained(output_dir) self.processor.save_pretrained(output_dir)
@override @override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save: if args.should_save:
getattr(self.processor, "image_processor").save_pretrained(args.output_dir) self.processor.save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback): class PissaConvertCallback(TrainerCallback):
@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
if args.should_save: if args.should_save:
model = kwargs.pop("model") model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init") pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir)) logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
if isinstance(model, PeftModel): if isinstance(model, PeftModel):
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
setattr(model.peft_config["default"], "init_lora_weights", True) setattr(model.peft_config["default"], "init_lora_weights", True)
@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
pissa_init_dir = os.path.join(args.output_dir, "pissa_init") pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup") pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted") pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir)) logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
# 1. save a pissa backup with init_lora_weights: True # 1. save a pissa backup with init_lora_weights: True
# 2. save a converted lora with init_lora_weights: pissa # 2. save a converted lora with init_lora_weights: pissa
# 3. load the pissa backup with init_lora_weights: True # 3. load the pissa backup with init_lora_weights: True
@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode: if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort) signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.root.addHandler(self.logger_handler) logging.add_handler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler) transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None: def _set_abort(self, signum, frame) -> None:
@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir and args.overwrite_output_dir
): ):
logger.warning("Previous trainer log in this folder will be deleted.") logger.warning_once("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG)) os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override @override
@ -288,13 +287,13 @@ class LogCallback(TrainerCallback):
logs = dict( logs = dict(
current_steps=self.cur_steps, current_steps=self.cur_steps,
total_steps=self.max_steps, total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None), loss=state.log_history[-1].get("loss"),
eval_loss=state.log_history[-1].get("eval_loss", None), eval_loss=state.log_history[-1].get("eval_loss"),
predict_loss=state.log_history[-1].get("predict_loss", None), predict_loss=state.log_history[-1].get("predict_loss"),
reward=state.log_history[-1].get("reward", None), reward=state.log_history[-1].get("reward"),
accuracy=state.log_history[-1].get("rewards/accuracies", None), accuracy=state.log_history[-1].get("rewards/accuracies"),
learning_rate=state.log_history[-1].get("learning_rate", None), lr=state.log_history[-1].get("learning_rate"),
epoch=state.log_history[-1].get("epoch", None), epoch=state.log_history[-1].get("epoch"),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time, elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time, remaining_time=self.remaining_time,
@ -305,16 +304,17 @@ class LogCallback(TrainerCallback):
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]: if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
vram_allocated, vram_reserved = get_peak_memory() vram_allocated, vram_reserved = get_peak_memory()
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2) logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2) logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)
logs = {k: v for k, v in logs.items() if v is not None} logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]): if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
logger.info( log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format( for extra_key in ("reward", "accuracy", "throughput"):
logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A") if logs.get(extra_key):
) log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
)
logger.info_rank0("{" + log_str + "}")
if self.thread_pool is not None: if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs) self.thread_pool.submit(self._write_log, args.output_dir, logs)
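The `RECORD_VRAM` branch above converts peak CUDA memory into GiB with the same `1024**3` divisor. Here is a small standalone sketch of that bookkeeping; it assumes a CUDA device and returns an empty dict otherwise.

```python
# Sketch: peak VRAM readings in GiB, as added to the trainer log entries.
import torch


def peak_vram_gib() -> dict:
    if not torch.cuda.is_available():
        return {}
    return {
        "vram_allocated": round(torch.cuda.max_memory_allocated() / (1024**3), 2),
        "vram_reserved": round(torch.cuda.max_memory_reserved() / (1024**3), 2),
    }


if __name__ == "__main__":
    print(peak_vram_gib())  # {} on CPU-only machines
```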

View File

@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
self.callback_handler.add_callback(PissaConvertCallback) self.callback_handler.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of the DPO Trainer with that of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r""" r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model. Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer):
elif self.loss_type == "simpo": elif self.loss_type == "simpo":
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps) losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
else: else:
raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type)) raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach() chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach() rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer):
if self.ftx_gamma > 1e-6: if self.ftx_gamma > 1e-6:
losses += self.ftx_gamma * sft_loss losses += self.ftx_gamma * sft_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else "" prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu() metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu() metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu() metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu() metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu() metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu() metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu() metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu() metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
if self.loss_type == "orpo": if self.loss_type == "orpo":
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu() metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu() metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
return losses.mean(), metrics return losses.mean(), metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss
@override
def log(self, logs: Dict[str, float]) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
key_list, metric_list = [], []
for key, metrics in self._stored_metrics[train_eval].items():
key_list.append(key)
metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())
del self._stored_metrics[train_eval]
if len(metric_list) < 10: # pad to a fixed length for all-reduce
for i in range(10 - len(metric_list)):
key_list.append(f"dummy_{i}")
metric_list.append(0.0)
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
for key, metric in zip(key_list, metric_list): # add remaining items
if not key.startswith("dummy_"):
logs[key] = metric
return Trainer.log(self, logs)
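The new `log` override pads the stored-metric vector to a fixed length so that every rank hands `accelerator.reduce` an identically shaped tensor, then drops the dummy slots afterwards. Below is a single-process sketch of that idea with the reduce step stubbed out to the identity; in a real run the tensor would be averaged across processes.

```python
# Sketch: padded metric averaging as used in the DPO trainer's log() override.
from typing import Dict, List

import torch


def reduce_stub(tensor: torch.Tensor, op: str = "mean") -> torch.Tensor:
    return tensor  # stand-in for accelerator.reduce on a single process


def averaged_logs(stored: Dict[str, List[float]], pad_to: int = 10) -> Dict[str, float]:
    key_list, metric_list = [], []
    for key, values in stored.items():
        key_list.append(key)
        metric_list.append(torch.tensor(values, dtype=torch.float).mean().item())

    for i in range(pad_to - len(metric_list)):  # pad so every rank reduces the same shape
        key_list.append(f"dummy_{i}")
        metric_list.append(0.0)

    reduced = reduce_stub(torch.tensor(metric_list, dtype=torch.float), "mean").tolist()
    return {k: v for k, v in zip(key_list, reduced) if not k.startswith("dummy_")}


if __name__ == "__main__":
    print(averaged_logs({"rewards/chosen": [0.5, 0.7], "rewards/rejected": [0.1, 0.3]}))
```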

View File

@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(SaveProcessorCallback(processor)) self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer):
""" """
return Trainer._get_train_sampler(self) return Trainer._get_train_sampler(self)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of the KTO Trainer with that of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
@override @override
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""
Runs forward pass and computes the log probabilities. Runs forward pass and computes the log probabilities.
""" """
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
model_inputs = { model_inputs = {
"input_ids": batch["{}input_ids".format(prefix)], "input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch["{}attention_mask".format(prefix)], "attention_mask": batch[f"{prefix}attention_mask"],
} }
if "{}token_type_ids".format(prefix) in batch: if f"{prefix}token_type_ids" in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
if "pixel_values" in batch: if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"] model_inputs["pixel_values"] = batch["pixel_values"]
@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer):
model_inputs["image_grid_thw"] = batch["image_grid_thw"] model_inputs["image_grid_thw"] = batch["image_grid_thw"]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)]) logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
return logps, logps / valid_length return logits, logps, logps / valid_length
@override @override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logps, target_logps_avg = self.forward(model, batch) target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad(): with torch.no_grad():
kl_logps, _ = self.forward(model, batch, prefix="kl_") _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
if len(target_logps) != len(batch["kto_tags"]): if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.") raise ValueError("Mismatched shape of inputs and labels.")
chosen_logits = target_logits[batch["kto_tags"]]
chosen_logps = target_logps[batch["kto_tags"]] chosen_logps = target_logps[batch["kto_tags"]]
rejected_logits = target_logits[~batch["kto_tags"]]
rejected_logps = target_logps[~batch["kto_tags"]] rejected_logps = target_logps[~batch["kto_tags"]]
chosen_logps_avg = target_logps_avg[batch["kto_tags"]] chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg
@override @override
def compute_reference_log_probs( def compute_reference_log_probs(
@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
ref_context = nullcontext() ref_context = nullcontext()
with torch.no_grad(), ref_context: with torch.no_grad(), ref_context:
reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward( reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
ref_model, batch ref_model, batch
) )
@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
Computes the KTO loss and other metrics for the given batch of inputs for train or test. Computes the KTO loss and other metrics for the given batch of inputs for train or test.
""" """
metrics = {} metrics = {}
policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = ( (
self.concatenated_forward(model, batch) policy_chosen_logps,
) policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_kl_logps,
policy_chosen_logps_avg,
) = self.concatenated_forward(model, batch)
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs( reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
model, batch model, batch
) )
@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer):
sft_loss = -policy_chosen_logps_avg sft_loss = -policy_chosen_logps_avg
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"]) losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) num_chosen = len(chosen_rewards)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) num_rejected = len(rejected_rewards)
if num_chosen > 0:
metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
metrics["count/chosen"] = float(num_chosen)
all_num_chosen = self.accelerator.gather(num_chosen).sum().item() if num_rejected > 0:
all_num_rejected = self.accelerator.gather(num_rejected).sum().item() metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
if all_num_chosen > 0: metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() metrics["count/rejected"] = float(num_rejected)
metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
metrics["count/chosen"] = all_num_chosen
if all_num_rejected > 0:
metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
metrics["count/rejected"] = all_num_rejected
metrics["kl"] = kl.item() metrics["kl"] = kl.item()
return losses, metrics return losses, metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss
@override
def log(self, logs: Dict[str, float]) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
prefix = "eval_" if train_eval == "eval" else ""
# Add averaged stored metrics to logs
key_list, metric_list = [], []
for key, metrics in self._stored_metrics[train_eval].items():
key_list.append(key)
metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())
del self._stored_metrics[train_eval]
if len(metric_list) < 9: # pad to a fixed length for all-reduce
for i in range(9 - len(metric_list)):
key_list.append(f"dummy_{i}")
metric_list.append(0.0)
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
if f"count/{split}" in metric_dict:
for key in ("rewards", "logps", "logits"):
logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
del metric_dict[f"{key}/{split}_sum"]
del metric_dict[f"count/{split}"]
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: # calculate reward margin
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
for key, metric in metric_dict.items(): # add remaining items
if not key.startswith("dummy_"):
logs[key] = metric
return Trainer.log(self, logs)
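The KTO `log` override accumulates chosen/rejected metrics as sums plus example counts (so they can be summed across ranks), then divides to recover per-example averages and derives the reward margin from the two averaged rewards. A tiny sketch of that final aggregation with made-up numbers:

```python
# Sketch: turning summed KTO metrics and counts back into averages and a margin.
from typing import Dict


def finalize_kto_logs(metric_dict: Dict[str, float]) -> Dict[str, float]:
    logs: Dict[str, float] = {}
    for split in ("chosen", "rejected"):
        count = metric_dict.get(f"count/{split}", 0.0)
        if count > 0:
            for key in ("rewards", "logps", "logits"):
                logs[f"{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / count
    if "rewards/chosen" in logs and "rewards/rejected" in logs:
        logs["rewards/margins"] = logs["rewards/chosen"] - logs["rewards/rejected"]
    return logs


if __name__ == "__main__":
    example = {
        "rewards/chosen_sum": 6.0, "logps/chosen_sum": -30.0, "logits/chosen_sum": 12.0, "count/chosen": 3.0,
        "rewards/rejected_sum": 2.0, "logps/rejected_sum": -50.0, "logits/rejected_sum": 8.0, "count/rejected": 2.0,
    }
    print(finalize_kto_logs(example))  # rewards/chosen=2.0, rewards/rejected=1.0, margins=1.0, ...
```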

View File

@ -81,7 +81,7 @@ def run_kto(
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"]) plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:

View File

@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone()) setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
device = v_head_layer.weight.device device = v_head_layer.weight.device
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device) v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device) v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:

View File

@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override from typing_extensions import override
from ...extras.logging import get_logger from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@ -58,7 +58,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer): class CustomPPOTrainer(PPOTrainer, Trainer):
@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
] ]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with is not None: if ppo_config.log_with is not None:
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.") logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
ppo_config.log_with = None ppo_config.log_with = None
# Create optimizer and scheduler # Create optimizer and scheduler
@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
) )
if self.args.max_steps > 0: if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs") logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")
self.amp_context = torch.autocast(self.current_device.type) self.amp_context = torch.autocast(self.current_device.type)
warnings.simplefilter("ignore") # remove gc warnings on ref model warnings.simplefilter("ignore") # remove gc warnings on ref model
@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(SaveProcessorCallback(processor)) self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero() self.state.is_world_process_zero = self.is_world_process_zero()
if self.is_world_process_zero(): logger.info_rank0("***** Running training *****")
logger.info("***** Running training *****") logger.info_rank0(f" Num examples = {num_examples:,}")
logger.info(" Num examples = {:,}".format(num_examples)) logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info(" Num Epochs = {:,}".format(num_train_epochs)) logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size)) logger.info_rank0(
logger.info( " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format( total_train_batch_size
total_train_batch_size
)
) )
logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps)) )
logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs)) logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info(" Total training steps = {:,}".format(max_steps)) logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0])) logger.info_rank0(f" Total training steps = {max_steps:,}")
logger.info_rank0(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
dataiter = iter(self.dataloader) dataiter = iter(self.dataloader)
loss_meter = AverageMeter() loss_meter = AverageMeter()
@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True) batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards) self.log_stats(stats, batch, rewards)
except Exception: except Exception:
logger.warning("Failed to save stats due to unknown errors.") logger.warning_rank0("Failed to save stats due to unknown errors.")
self.state.global_step += 1 self.state.global_step += 1
self.callback_handler.on_step_end(self.args, self.state, self.control) self.callback_handler.on_step_end(self.args, self.state, self.control)
@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if (step + 1) % self.args.save_steps == 0: # save checkpoint if (step + 1) % self.args.save_steps == 0: # save checkpoint
self.save_model( self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)) os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
) )
self.callback_handler.on_save(self.args, self.state, self.control) self.callback_handler.on_save(self.args, self.state, self.control)
@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.args.should_save: if self.args.should_save:
self._save(output_dir, state_dict=state_dict) self._save(output_dir, state_dict=state_dict)
except ValueError: except ValueError:
logger.warning( logger.warning_rank0(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead," " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
" use zero_to_fp32.py to recover weights" " use zero_to_fp32.py to recover weights"
) )

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from transformers import Trainer from transformers import Trainer
from typing_extensions import override from typing_extensions import override
from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@ -30,9 +30,6 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments
logger = get_logger(__name__)
class CustomTrainer(Trainer): class CustomTrainer(Trainer):
r""" r"""
Inherits Trainer for custom optimizer. Inherits Trainer for custom optimizer.
@ -51,7 +48,7 @@ class CustomTrainer(Trainer):
self.add_callback(PissaConvertCallback) self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@ -68,3 +65,19 @@ class CustomTrainer(Trainer):
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other models should not scale the loss
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss
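The `compute_loss` override above compensates for a loss-scaling change in transformers 4.46.0 by dividing the returned loss by `gradient_accumulation_steps` when the model does not accept loss kwargs. A rough sketch of what a version guard such as `is_transformers_version_equal_to_4_46` might look like; the real helper lives in `src/llamafactory/extras/packages.py` and may be implemented differently:

```python
# Hedged sketch of a 4.46-only version guard; the actual helper may pin
# specific patch releases or cache the result.
import transformers
from packaging import version

def is_transformers_version_equal_to_4_46() -> bool:
    v = version.parse(transformers.__version__)
    return version.parse("4.46.0") <= v < version.parse("4.47.0")
```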

View File

@ -24,7 +24,8 @@ import torch
from transformers import Trainer from transformers import Trainer
from typing_extensions import override from typing_extensions import override
from ...extras.logging import get_logger from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@ -36,7 +37,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer): class PairwiseTrainer(Trainer):
@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer):
self.add_callback(PissaConvertCallback) self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
@override @override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r""" r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze() chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean() loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
if return_outputs: if return_outputs:
return loss, (loss, chosen_scores, rejected_scores) return loss, (loss, chosen_scores, rejected_scores)
else: else:
@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer):
return return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}") logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer: with open(output_prediction_file, "w", encoding="utf-8") as writer:
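For context, the pairwise objective computed by this trainer is the Bradley-Terry style ranking loss `-log(sigmoid(chosen - rejected))`. A small self-contained illustration with toy scores (not taken from the repository):

```python
# Toy illustration of the pairwise ranking loss used by PairwiseTrainer.
import torch
import torch.nn.functional as F

chosen_scores = torch.tensor([1.2, 0.3, 2.0])
rejected_scores = torch.tensor([0.5, 0.9, -1.0])

# The loss shrinks as chosen scores exceed rejected scores.
loss = -F.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss.item())
```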

View File

@ -25,8 +25,9 @@ import torch
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from typing_extensions import override from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@ -39,7 +40,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer): class CustomSeq2SeqTrainer(Seq2SeqTrainer):
@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.add_callback(PissaConvertCallback) self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other models should not scale the loss
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss
@override @override
def prediction_step( def prediction_step(
self, self,
@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}") logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
labels = np.where( labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id

View File

@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert set(state_dict_a.keys()) == set(state_dict_b.keys()) assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys(): for name in state_dict_a.keys():
if any(key in name for key in diff_keys): if any(key in name for key in diff_keys):
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
else: else:
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
@ -80,18 +80,17 @@ def load_reference_model(
is_trainable: bool = False, is_trainable: bool = False,
add_valuehead: bool = False, add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]: ) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead: if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=get_current_device() model_path, torch_dtype=torch.float16, device_map=current_device
) )
if not is_trainable: if not is_trainable:
model.v_head = model.v_head.to(torch.float16) model.v_head = model.v_head.to(torch.float16)
return model return model
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
model_path, torch_dtype=torch.float16, device_map=get_current_device()
)
if use_lora or use_pissa: if use_lora or use_pissa:
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
@ -110,7 +109,7 @@ def load_train_dataset(**kwargs) -> "Dataset":
return dataset_module["train_dataset"] return dataset_module["train_dataset"]
def patch_valuehead_model(): def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None: def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")} state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False) self.v_head.load_state_dict(state_dict, strict=False)
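The test helper above hoists `get_current_device()` into a local variable so both branches share the same device map. A rough, CUDA-only sketch of what such a helper might return, assuming the local rank is read from the torchrun environment; the actual helper in `src/llamafactory/extras/misc.py` may cover other backends as well:

```python
# Assumption-laden sketch of get_current_device; the real helper may support
# additional accelerators (NPU, XPU, MPS) and different fallbacks.
import os
import torch

def get_current_device() -> "torch.device":
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if torch.cuda.is_available():
        return torch.device("cuda", local_rank)
    return torch.device("cpu")
```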

View File

@ -28,8 +28,8 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
@ -46,7 +46,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments from ..hparams import DataArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer): class DummyOptimizer(torch.optim.Optimizer):
@ -116,7 +116,7 @@ def create_ref_model(
ref_model = load_model( ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
) )
logger.info("Created reference model from {}".format(finetuning_args.ref_model)) logger.info_rank0(f"Created reference model from {finetuning_args.ref_model}")
else: else:
if finetuning_args.finetuning_type == "lora": if finetuning_args.finetuning_type == "lora":
ref_model = None ref_model = None
@ -127,7 +127,7 @@ def create_ref_model(
ref_model = load_model( ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
) )
logger.info("Created reference model from the model itself.") logger.info_rank0("Created reference model from the model itself.")
return ref_model return ref_model
@ -140,7 +140,7 @@ def create_reward_model(
""" """
if finetuning_args.reward_model_type == "api": if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url." assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info("Use reward server {}".format(finetuning_args.reward_model)) logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
return finetuning_args.reward_model return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora": elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
@ -157,7 +157,7 @@ def create_reward_model(
model.register_buffer( model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
) )
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) logger.info_rank0(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
return None return None
else: else:
reward_model_args = ModelArguments.copyfrom( reward_model_args = ModelArguments.copyfrom(
@ -171,8 +171,8 @@ def create_reward_model(
reward_model = load_model( reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
) )
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model)) logger.info_rank0(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") logger.warning_rank0("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model return reward_model
@ -231,7 +231,7 @@ def _create_galore_optimizer(
elif training_args.optim == "adafactor": elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor optim_class = GaLoreAdafactor
else: else:
raise NotImplementedError("Unknow optim: {}".format(training_args.optim)) raise NotImplementedError(f"Unknown optim: {training_args.optim}")
if finetuning_args.galore_layerwise: if finetuning_args.galore_layerwise:
if training_args.gradient_accumulation_steps != 1: if training_args.gradient_accumulation_steps != 1:
@ -265,7 +265,7 @@ def _create_galore_optimizer(
] ]
optimizer = optim_class(param_groups, **optim_kwargs) optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
return optimizer return optimizer
@ -305,7 +305,7 @@ def _create_loraplus_optimizer(
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay), dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
] ]
optimizer = optim_class(param_groups, **optim_kwargs) optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio)) logger.info_rank0(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
return optimizer return optimizer
@ -343,7 +343,7 @@ def _create_badam_optimizer(
verbose=finetuning_args.badam_verbose, verbose=finetuning_args.badam_verbose,
ds_zero3_enabled=is_deepspeed_zero3_enabled(), ds_zero3_enabled=is_deepspeed_zero3_enabled(),
) )
logger.info( logger.info_rank0(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, " f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_interval} steps, " f"switch block every {finetuning_args.badam_switch_interval} steps, "
f"default start block is {finetuning_args.badam_start_block}" f"default start block is {finetuning_args.badam_start_block}"
@ -362,7 +362,7 @@ def _create_badam_optimizer(
include_embedding=False, include_embedding=False,
**optim_kwargs, **optim_kwargs,
) )
logger.info( logger.info_rank0(
f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, " f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}" f"mask mode is {finetuning_args.badam_mask_mode}"
) )
@ -391,7 +391,7 @@ def _create_adam_mini_optimizer(
n_heads=num_q_head, n_heads=num_q_head,
n_kv_heads=num_kv_head, n_kv_heads=num_kv_head,
) )
logger.info("Using Adam-mini optimizer.") logger.info_rank0("Using Adam-mini optimizer.")
return optimizer return optimizer
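Many hunks in this diff replace `logger.info(...)` with `logger.info_rank0(...)`. A minimal sketch of how such rank-zero-only helpers could be wired up, assuming the process rank is read from the environment set by torchrun; the real `extras.logging` module may implement this differently:

```python
# Minimal sketch of rank-zero-only logging helpers; the names mirror the diff,
# the implementation is an assumption.
import logging
import os

class Rank0Logger(logging.Logger):
    @staticmethod
    def _is_rank0() -> bool:
        return int(os.environ.get("RANK", "0")) == 0

    def info_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_rank0():
            self.info(msg, *args, **kwargs)

    def warning_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_rank0():
            self.warning(msg, *args, **kwargs)

def get_logger(name: str) -> logging.Logger:
    # Loggers created while Rank0Logger is the active class gain the helpers.
    logging.setLoggerClass(Rank0Logger)
    try:
        return logging.getLogger(name)
    finally:
        logging.setLoggerClass(logging.Logger)
```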

View File

@ -20,8 +20,8 @@ import torch
from transformers import PreTrainedModel from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .callbacks import LogCallback from .callbacks import LogCallback
@ -37,7 +37,7 @@ if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
elif finetuning_args.stage == "kto": elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks) run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
else: else:
raise ValueError("Unknown task: {}.".format(finetuning_args.stage)) raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def export_model(args: Optional[Dict[str, Any]] = None) -> None: def export_model(args: Optional[Dict[str, Any]] = None) -> None:
@ -91,18 +91,18 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
setattr(model.config, "torch_dtype", output_dtype) setattr(model.config, "torch_dtype", output_dtype)
model = model.to(output_dtype) model = model.to(output_dtype)
logger.info("Convert model dtype to: {}.".format(output_dtype)) logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
model.save_pretrained( model.save_pretrained(
save_directory=model_args.export_dir, save_directory=model_args.export_dir,
max_shard_size="{}GB".format(model_args.export_size), max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format), safe_serialization=(not model_args.export_legacy_format),
) )
if model_args.export_hub_model_id is not None: if model_args.export_hub_model_id is not None:
model.push_to_hub( model.push_to_hub(
model_args.export_hub_model_id, model_args.export_hub_model_id,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
max_shard_size="{}GB".format(model_args.export_size), max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format), safe_serialization=(not model_args.export_legacy_format),
) )
@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME), os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME), os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
) )
logger.info("Copied valuehead to {}.".format(model_args.export_dir)) logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)): elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
shutil.copy( shutil.copy(
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME), os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME), os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
) )
logger.info("Copied valuehead to {}.".format(model_args.export_dir)) logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
try: try:
tokenizer.padding_side = "left" # restore padding side tokenizer.padding_side = "left" # restore padding side
@ -133,11 +133,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
if processor is not None: if processor is not None:
getattr(processor, "image_processor").save_pretrained(model_args.export_dir) processor.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None: if model_args.export_hub_model_id is not None:
getattr(processor, "image_processor").push_to_hub( processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
model_args.export_hub_model_id, token=model_args.hf_hub_token
)
except Exception: except Exception as e:
logger.warning("Cannot save tokenizer, please copy the files manually.") logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")

View File

@ -141,7 +141,14 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = "" chatbot[-1][1] = ""
response = "" response = ""
for new_text in self.stream_chat( for new_text in self.stream_chat(
messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature messages,
system,
tools,
images=[image] if image else None,
videos=[video] if video else None,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
): ):
response += new_text response += new_text
if tools: if tools:
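The call above now wraps single image and video inputs into one-element lists before handing them to `stream_chat`. A hypothetical wrapper showing the same calling convention outside the WebUI; the chat model instance and the empty `system`/`tools` strings are assumptions, only the argument names come from the diff:

```python
# Hypothetical wrapper; chat_model is assumed to be a llamafactory ChatModel.
from typing import Optional

def stream_reply(chat_model, prompt: str, image: Optional[str] = None, video: Optional[str] = None) -> str:
    messages = [{"role": "user", "content": prompt}]
    response = ""
    for new_text in chat_model.stream_chat(
        messages,
        "",  # system prompt (assumed empty here)
        "",  # tools (assumed empty here)
        images=[image] if image else None,  # single path wrapped into a list, or None
        videos=[video] if video else None,
        max_new_tokens=256,
        top_p=0.9,
        temperature=0.7,
    ):
        response += new_text
    return response
```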

View File

@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, Tuple
from yaml import safe_dump, safe_load from yaml import safe_dump, safe_load
from ..extras import logging
from ..extras.constants import ( from ..extras.constants import (
CHECKPOINT_NAMES, CHECKPOINT_NAMES,
DATA_CONFIG, DATA_CONFIG,
@ -30,8 +31,7 @@ from ..extras.constants import (
VISION_MODELS, VISION_MODELS,
DownloadSource, DownloadSource,
) )
from ..extras.logging import get_logger from ..extras.misc import use_modelscope, use_openmind
from ..extras.misc import use_modelscope
from ..extras.packages import is_gradio_available from ..extras.packages import is_gradio_available
@ -39,7 +39,7 @@ if is_gradio_available():
import gradio as gr import gradio as gr
logger = get_logger(__name__) logger = logging.get_logger(__name__)
DEFAULT_CACHE_DIR = "cache" DEFAULT_CACHE_DIR = "cache"
@ -56,7 +56,7 @@ def get_save_dir(*paths: str) -> os.PathLike:
Gets the path to saved model checkpoints. Gets the path to saved model checkpoints.
""" """
if os.path.sep in paths[-1]: if os.path.sep in paths[-1]:
logger.warning("Found complex path, some features may be not available.") logger.warning_rank0("Found complex path, some features may be not available.")
return paths[-1] return paths[-1]
paths = (path.replace(" ", "").strip() for path in paths) paths = (path.replace(" ", "").strip() for path in paths)
@ -75,7 +75,7 @@ def load_config() -> Dict[str, Any]:
Loads user config if exists. Loads user config if exists.
""" """
try: try:
with open(get_config_path(), "r", encoding="utf-8") as f: with open(get_config_path(), encoding="utf-8") as f:
return safe_load(f) return safe_load(f)
except Exception: except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@ -109,19 +109,19 @@ def get_model_path(model_name: str) -> str:
use_modelscope() use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE) and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT) and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace path ): # replace hf path with ms path
model_path = path_dict.get(DownloadSource.MODELSCOPE) model_path = path_dict.get(DownloadSource.MODELSCOPE)
if (
use_openmind()
and path_dict.get(DownloadSource.OPENMIND)
and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace hf path with om path
model_path = path_dict.get(DownloadSource.OPENMIND)
return model_path return model_path
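`use_modelscope` and `use_openmind` gate the hub substitution above. A plausible sketch of these helpers, assuming they simply read the `USE_MODELSCOPE_HUB` and `USE_OPENMIND_HUB` environment variables; the actual versions in `src/llamafactory/extras/misc.py` may differ:

```python
# Hedged sketch: environment-variable switches for alternative model hubs.
import os

def _env_flag(name: str) -> bool:
    return os.getenv(name, "0").lower() in ("true", "1")

def use_modelscope() -> bool:
    return _env_flag("USE_MODELSCOPE_HUB")

def use_openmind() -> bool:
    return _env_flag("USE_OPENMIND_HUB")
```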
def get_prefix(model_name: str) -> str:
r"""
Gets the prefix of the model name to obtain the model family.
"""
return model_name.split("-")[0]
def get_model_info(model_name: str) -> Tuple[str, str]: def get_model_info(model_name: str) -> Tuple[str, str]:
r""" r"""
Gets the necessary information of this model. Gets the necessary information of this model.
@ -137,21 +137,14 @@ def get_template(model_name: str) -> str:
r""" r"""
Gets the template name if the model is a chat model. Gets the template name if the model is a chat model.
""" """
if ( return DEFAULT_TEMPLATE.get(model_name, "default")
model_name
and any(suffix in model_name for suffix in ("-Chat", "-Instruct"))
and get_prefix(model_name) in DEFAULT_TEMPLATE
):
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"
def get_visual(model_name: str) -> bool: def get_visual(model_name: str) -> bool:
r""" r"""
Judges if the model is a vision language model. Judges if the model is a vision language model.
""" """
return get_prefix(model_name) in VISION_MODELS return model_name in VISION_MODELS
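With this change, `get_template` and `get_visual` look a model up by its full name instead of a name prefix. A toy illustration of the new behaviour; the `DEFAULT_TEMPLATE` mapping below is a stand-in, not the real table from `extras/constants.py`:

```python
# Stand-in mapping for illustration only; real entries live in extras/constants.py.
DEFAULT_TEMPLATE = {"Llama-3-8B-Instruct": "llama3", "Qwen2-7B-Instruct": "qwen"}

def get_template(model_name: str) -> str:
    return DEFAULT_TEMPLATE.get(model_name, "default")

print(get_template("Llama-3-8B-Instruct"))  # "llama3"
print(get_template("Llama-3-8B"))           # "default" (base models fall back)
```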
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
@ -179,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
Loads dataset_info.json. Loads dataset_info.json.
""" """
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"): if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir)) logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {} return {}
try: try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
return json.load(f) return json.load(f)
except Exception as err: except Exception as err:
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err))) logger.warning_rank0(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
return {} return {}

View File

@ -41,7 +41,7 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button": def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
try: try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except Exception: except Exception:
return gr.Button(interactive=False) return gr.Button(interactive=False)
@ -57,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
def _load_data_file(file_path: str) -> List[Any]: def _load_data_file(file_path: str) -> List[Any]:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"): if file_path.endswith(".json"):
return json.load(f) return json.load(f)
elif file_path.endswith(".jsonl"): elif file_path.endswith(".jsonl"):
@ -67,7 +67,7 @@ def _load_data_file(file_path: str) -> List[Any]:
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]) data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])

Some files were not shown because too many files have changed in this diff