commit d99e164cad
Author: hoshi-hiyouga
Date:   2024-11-02 21:20:27 +08:00

    Merge branch 'main' into main

    Former-commit-id: 5f14910910154ba569435e7e68acbd6c30f79e80

147 changed files with 3087 additions and 1833 deletions


@@ -7,6 +7,8 @@ data
docker
saves
hf_cache
ms_cache
om_cache
output
.dockerignore
.gitattributes


@ -1,32 +1,35 @@
# Note: actually we do not support .env, just for reference # Note: actually we do not support .env, just for reference
# api # api
API_HOST=0.0.0.0 API_HOST=
API_PORT=8000 API_PORT=
API_KEY= API_KEY=
API_MODEL_NAME=gpt-3.5-turbo API_MODEL_NAME=
FASTAPI_ROOT_PATH= FASTAPI_ROOT_PATH=
MAX_CONCURRENT=
# general # general
DISABLE_VERSION_CHECK= DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS= FORCE_CHECK_IMPORTS=
LLAMAFACTORY_VERBOSITY= LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB= USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
RECORD_VRAM= RECORD_VRAM=
# torchrun # torchrun
FORCE_TORCHRUN= FORCE_TORCHRUN=
MASTER_ADDR= MASTER_ADDR=
MASTER_PORT= MASTER_PORT=
NNODES= NNODES=
RANK= NODE_RANK=
NPROC_PER_NODE= NPROC_PER_NODE=
# wandb # wandb
WANDB_DISABLED= WANDB_DISABLED=
WANDB_PROJECT=huggingface WANDB_PROJECT=
WANDB_API_KEY= WANDB_API_KEY=
# gradio ui # gradio ui
GRADIO_SHARE=False GRADIO_SHARE=
GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_NAME=
GRADIO_SERVER_PORT= GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH= GRADIO_ROOT_PATH=
GRADIO_IPV6=
# setup # setup
ENABLE_SHORT_CONSOLE=1 ENABLE_SHORT_CONSOLE=1
# reserved (do not use) # reserved (do not use)
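These variables are read by the LLaMA Factory launcher at startup. As a rough sketch (the values are illustrative, and the yaml paths are the example configs referenced elsewhere in this commit), the renamed `NODE_RANK` and the new `USE_OPENMIND_HUB` switch would be used like this:

```bash
# pull models/datasets from the Modelers (openMind) hub and start the OpenAI-style API server
export USE_OPENMIND_HUB=1
API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml

# multi-node training now reads NODE_RANK instead of RANK
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 \
    llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```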


@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:

### Style guide

LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
### Create a Pull Request
1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
git clone git@github.com:[username]/LLaMA-Factory.git
cd LLaMA-Factory
git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
```
3. Create a new branch to hold your development changes:
```bash
git checkout -b dev_your_branch
```
4. Set up a development environment by running the following command in a virtual environment:
```bash
pip install -e ".[dev]"
```
If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the -e flag.
5. Check code before commit:
```bash
make commit
make style && make quality
make test
```
6. Submit changes:
```bash
git add .
git commit -m "commit message"
git fetch upstream
git rebase upstream/main
git push -u origin dev_your_branch
```
7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).
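
If you prefer working from the terminal, the GitHub CLI can open the pull request as well; a sketch, assuming `gh` is installed and authenticated (`<username>` is a placeholder):

```bash
gh pr create --repo hiyouga/LLaMA-Factory --base main \
    --head <username>:dev_your_branch --fill
```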


@@ -26,15 +26,15 @@ jobs:
        uses: actions/setup-python@v5
        with:
          python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install build
      - name: Build package
        run: |
          python -m build
      - name: Publish package
        uses: pypa/gh-action-pypi-publish@release/v1


@@ -22,7 +22,7 @@ jobs:
      fail-fast: false
      matrix:
        python-version:
-          - "3.8"
+          - "3.8" # TODO: remove py38 in next transformers release
          - "3.9"
          - "3.10"
          - "3.11"

@@ -54,7 +54,6 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          python -m pip install git+https://github.com/huggingface/transformers.git
          python -m pip install ".[torch,dev]"
      - name: Check quality
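
With the pinned transformers install removed, CI now tests against the released package. The same steps can be reproduced locally with the Makefile targets from this commit (a sketch):

```bash
python -m pip install --upgrade pip
python -m pip install ".[torch,dev]"
make quality   # ruff check + ruff format --check
make test      # CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
```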

.gitignore

@@ -162,6 +162,7 @@ cython_debug/

# custom .gitignore
ms_cache/
hf_cache/
om_cache/
cache/
config/
saves/

.pre-commit-config.yaml (new file)

@@ -0,0 +1,28 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-ast
      - id: check-added-large-files
        args: ['--maxkb=25000']
      - id: check-merge-conflict
      - id: check-yaml
      - id: debug-statements
      - id: end-of-file-fixer
      - id: trailing-whitespace
        args: [--markdown-linebreak-ext=md]
      - id: no-commit-to-branch
        args: ['--branch', 'main']

  - repo: https://github.com/asottile/pyupgrade
    rev: v3.17.0
    hooks:
      - id: pyupgrade
        args: [--py38-plus]

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.9
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
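
To use these hooks locally, install pre-commit and register them once; a short sketch (this is also what the new `make commit` target below runs):

```bash
pip install pre-commit        # if not already available in your environment
pre-commit install            # register the hooks in .git/hooks
pre-commit run --all-files    # run every hook against the whole tree
```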


@@ -1,7 +1,14 @@
-.PHONY: quality style test
+.PHONY: build commit quality style test

check_dirs := scripts src tests setup.py

build:
	pip install build && python -m build

commit:
	pre-commit install
	pre-commit run --all-files

quality:
	ruff check $(check_dirs)
	ruff format --check $(check_dirs)

@@ -11,4 +18,4 @@ style:
	ruff format $(check_dirs)

test:
-	CUDA_VISIBLE_DEVICES= pytest tests/
+	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
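
Typical usage of the new targets, as a sketch:

```bash
make commit    # install and run all pre-commit hooks
make build     # build the sdist/wheel with python -m build
make test      # run the test suite on CPU with W&B disabled
```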

README.md

@@ -4,7 +4,7 @@

[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-91-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-92-green)](#projects-using-llama-factory)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -26,10 +26,17 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3

Choose your path:

- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
-- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Local machine**: Please refer to [usage](#getting-started)
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/

Recent activities:

- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)

> [!NOTE]
> Apart from the links above, all other websites are unauthorized third-party sites. Please use them with caution.

## Table of Contents

- [Features](#features)
@@ -72,6 +79,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.

[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
@@ -130,7 +139,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).

-[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.

[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
@@ -162,36 +171,39 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Supported Models

| Model | Model size | Template |
| ----- | ---------- | -------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
-| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi-small |
+| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
+| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -360,7 +372,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
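
For instance, the new `openmind` extra can be combined with the usual ones in the editable install shown above (a sketch):

```bash
pip install -e ".[torch,metrics,openmind]"
```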
@@ -412,7 +424,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij

### Data Preparation

-Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
+Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope / Modelers hub or load the dataset in local disk.

> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset.
@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
docker run -dit --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v /usr/local/dcmi:/usr/local/dcmi \
@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v ./saves:/app/saves \
@@ -557,6 +572,7 @@ docker exec -it llamafactory bash

- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
- `om_cache`: Similar to Hugging Face cache but for Modelers users.
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml

> [!TIP]
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
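
Since the server exposes an OpenAI-style interface, a quick smoke test can be done with plain `curl`; a sketch, assuming the API was started as above on port 8000 (the payload fields follow the OpenAI chat schema and the model name is a placeholder):

```bash
curl -s http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "llama3", "messages": [{"role": "user", "content": "Hello!"}]}'
```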
### Download from ModelScope Hub

@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows

Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
### Download from Modelers Hub
You can also use Modelers Hub to download models and datasets.
```bash
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
```
Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
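
As a sketch, combining the switch with the model ID mentioned above (any other Modelers model ID works the same way):

```bash
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
llamafactory-cli webui    # or set TeleAI/TeleChat-7B-pt as model_name_or_path in your yaml
```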
### Use W&B Logger

To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
@@ -684,11 +712,13 @@ If you have a project that should be incorporated, please contact via email or c

1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
-1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)

</details>
@@ -696,7 +726,7 @@ If you have a project that should be incorporated, please contact via email or c

This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation


@@ -4,7 +4,7 @@

[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Factory)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/)
-[![Citation](https://img.shields.io/badge/citation-91-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-92-green)](#使用了-llama-factory-的项目)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
[![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
[![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -26,11 +26,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

Choose your path:

- **Colab**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
-- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **PAI-DSW**: [Llama3 example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Local machine**: see [How to use](#如何使用)
- **Getting started tutorial**: https://zhuanlan.zhihu.com/p/695287607
- **Framework documentation**: https://llamafactory.readthedocs.io/zh-cn/latest/

Recent activities:

- **2024/10/18-2024/11/30**: Build a personalized tour-guide bot with PAI + LLaMA Factory. [[event page]](https://developer.aliyun.com/topic/llamafactory2)

> [!NOTE]
> Websites other than the links above are unauthorized third-party sites. Please use them with caution.

## Table of Contents

- [Features](#项目特色)
@@ -73,6 +80,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## Changelog

[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#从魔乐社区下载) for usage.

[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thanks to [@simonJJJ](https://github.com/simonJJJ)'s PR.
@@ -163,35 +172,38 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## Models

| Model | Model size | Template |
| ----- | ---------- | -------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
+| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]
> For all "base" models, the `template` argument can be any of `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for "instruct/chat" models.
@@ -360,7 +372,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

-Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
+Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality

> [!TIP]
> When package conflicts occur, you can use `pip install --no-deps -e .` to resolve them.
@@ -412,7 +424,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh

### Data Preparation

-For the dataset file format, please refer to [data/README_zh.md](data/README_zh.md). You can use datasets from HuggingFace / ModelScope or load a local dataset.
+For the dataset file format, please refer to [data/README_zh.md](data/README_zh.md). You can use datasets from HuggingFace / ModelScope / Modelers or load a local dataset.

> [!NOTE]
> Please update `data/dataset_info.json` when using a custom dataset.
@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
docker run -dit --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v /usr/local/dcmi:/usr/local/dcmi \
@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v ./saves:/app/saves \
@@ -557,6 +572,7 @@ docker exec -it llamafactory bash

- `hf_cache`: Use the host machine's Hugging Face cache folder; it can be redirected to a different directory.
- `ms_cache`: Similar to the Hugging Face cache folder, for ModelScope users.
- `om_cache`: Similar to the Hugging Face cache folder, for Modelers users.
- `data`: Path to the folder on the host machine that stores datasets.
- `output`: Set the export directory to this path so the exported model can be accessed on the host machine.
@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml

> [!TIP]
> See [here](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
### Download from ModelScope Hub

@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`

Set `model_name_or_path` to a model ID to load the corresponding model. Browse all available models at the [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
### Download from Modelers Hub

You can also download models and datasets from the Modelers Hub as follows.

```bash
export USE_OPENMIND_HUB=1 # use `set USE_OPENMIND_HUB=1` on Windows
```

Set `model_name_or_path` to a model ID to load the corresponding model. Browse all available models at the [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
### Use the W&B Logger

To log experiment results with [Weights & Biases](https://wandb.ai), add the following arguments to the yaml file.
@@ -684,11 +712,12 @@ run_name: test_run # 可选

1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A Chinese medical large language model fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of Chinese medical large language models based on LLaMA2-7B and Baichuan-13B, fine-tuned on Chinese medical data.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI personality large language models that can give any LLM one of 16 personality types depending on the dataset and training method.
-1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model for generating Stable Diffusion prompts. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model for generating Stable Diffusion prompts. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A Chinese multimodal medical large language model fine-tuned from LLaVA-1.5-7B on Chinese multimodal medical data.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning large language models on Windows PCs with NVIDIA RTX devices.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: A low-code tool for building multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)

</details>
@@ -696,7 +725,7 @@ run_name: test_run # 可选

The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

File diff suppressed because it is too large

Three binary image assets were updated (29 KiB → 28 KiB, 145 KiB → 132 KiB, 149 KiB → 132 KiB); binary diffs are not shown.


@@ -4999,4 +4999,4 @@
"input": "Time waits for no one.",
"output": "No one can stop time from moving forward."
}
]


@@ -4999,4 +4999,4 @@
"input": "",
"output": "安第斯山脉位于南美洲横跨七个国家包括委内瑞拉哥伦比亚厄瓜多尔秘鲁玻利维亚智利和阿根廷。安第斯山脉是世界上最长的山脉之一全长约7,000千米4,350英里其山脉沿着南美洲西海岸蜿蜒延伸平均海拔约为4,000米13,000英尺。在其南部安第斯山脉宽度达到700千米430英里在其北部宽度约为500千米310英里。"
}
]


@@ -17,9 +17,9 @@ _CITATION = """\
}
"""

-_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
_LICENSE = "gpl-3.0"
-_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"


class BelleMultiturn(datasets.GeneratorBasedBuilder):

@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str):
-        with open(filepath, "r", encoding="utf-8") as f:
+        with open(filepath, encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                conversations = []
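
These f-string and `open()` cleanups are exactly the kind of rewrite the new pyupgrade hook enforces; once the hooks are installed, they can be applied repo-wide (a sketch):

```bash
pre-commit run pyupgrade --all-files
pre-commit run ruff --all-files
```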

File diff suppressed because one or more lines are too long


@@ -54,7 +54,8 @@
  },
  "alpaca_en": {
    "hf_hub_url": "llamafactory/alpaca_en",
-    "ms_hub_url": "llamafactory/alpaca_en"
+    "ms_hub_url": "llamafactory/alpaca_en",
+    "om_hub_url": "HaM/alpaca_en"
  },
  "alpaca_zh": {
    "hf_hub_url": "llamafactory/alpaca_zh",

@@ -66,7 +67,8 @@
  },
  "alpaca_gpt4_zh": {
    "hf_hub_url": "llamafactory/alpaca_gpt4_zh",
-    "ms_hub_url": "llamafactory/alpaca_gpt4_zh"
+    "ms_hub_url": "llamafactory/alpaca_gpt4_zh",
+    "om_hub_url": "State_Cloud/alpaca-gpt4-data-zh"
  },
  "glaive_toolcall_en": {
    "hf_hub_url": "llamafactory/glaive_toolcall_en",

@@ -623,4 +625,4 @@
    },
    "folder": "python"
  }
}


@@ -7223,4 +7223,4 @@
"value": "Abstraction is a principle in object-oriented programming that refers to the process of focusing on the essential features of an object or concept without emphasizing the details that are not important for its functionality. It enables programmers to create classes that represent the objects they want to work with, and specifies only the behavior of those objects, leaving out unnecessary details. Abstraction helps to make the code more maintainable, modular, and scalable. It also improves the productivity of developers by reducing the amount of code they need to write."
}
}
]


@@ -5055,4 +5055,4 @@
"value": "C. 参与讨论"
}
}
]


@@ -9155,4 +9155,4 @@
],
"tools": "[]"
}
]


@@ -9019,4 +9019,4 @@
],
"tools": "[]"
}
]


@@ -8,9 +8,9 @@ import datasets

_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
-_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
-_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",

@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
    def _generate_examples(self, filepaths: List[str]):
        key = 0
        for filepath in filepaths:
-            with open(filepath, "r", encoding="utf-8") as f:
+            with open(filepath, encoding="utf-8") as f:
                for row in f:
                    data = json.loads(row)
                    chosen = data["chosen"]


@@ -454,4 +454,4 @@
"input": "",
"output": "抱歉我不是 OpenAI 开发的 ChatGPT我是 {{author}} 开发的 {{name}},旨在为用户提供智能化的回答和帮助。"
}
]


@@ -5395,4 +5395,4 @@
],
"label": false
}
]


@@ -137,4 +137,4 @@
"mllm_demo_data/3.jpg"
]
}
]


@@ -44,4 +44,4 @@
"mllm_demo_data/3.mp4"
]
}
]


@@ -20,9 +20,9 @@ _CITATION = """\
}
"""

-_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
-_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
+_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"


class UltraChat(datasets.GeneratorBasedBuilder):

@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
    def _generate_examples(self, filepaths: List[str]):
        for filepath in filepaths:
-            with open(filepath, "r", encoding="utf-8") as f:
+            with open(filepath, encoding="utf-8") as f:
                for row in f:
                    try:
                        data = json.loads(row)

File diff suppressed because one or more lines are too long


@@ -1,6 +1,7 @@
-# Use the NVIDIA official image with PyTorch 2.3.0
-# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
-FROM nvcr.io/nvidia/pytorch:24.02-py3
+# Default use the NVIDIA official image with PyTorch 2.3.0
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
+ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
+FROM ${BASE_IMAGE}

# Define environments
ENV MAX_JOBS=4

@@ -12,6 +13,9 @@ ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG INSTALL_EETQ=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory

@@ -38,6 +42,15 @@ RUN EXTRA_PACKAGES="metrics"; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
    fi; \
    if [ "$INSTALL_HQQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
    fi; \
    if [ "$INSTALL_EETQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention
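
These build arguments are opt-in; a sketch of enabling the new extras when building the CUDA image (the image tag is a placeholder):

```bash
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg INSTALL_LIGER_KERNEL=true \
    --build-arg INSTALL_HQQ=true \
    --build-arg INSTALL_EETQ=true \
    -t llamafactory:latest .
```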


@@ -8,11 +8,15 @@ services:
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        INSTALL_FLASHATTN: false
        INSTALL_LIGER_KERNEL: false
        INSTALL_HQQ: false
        INSTALL_EETQ: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
    ports:


@@ -10,6 +10,7 @@ services:
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
      - /usr/local/dcmi:/usr/local/dcmi


@@ -10,6 +10,8 @@ ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory

@@ -36,6 +38,12 @@ RUN EXTRA_PACKAGES="metrics"; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
    fi; \
    if [ "$INSTALL_HQQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention


@@ -8,11 +8,14 @@ services:
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        INSTALL_FLASHATTN: false
        INSTALL_LIGER_KERNEL: false
        INSTALL_HQQ: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
      - ../../saves:/app/saves


@@ -207,4 +207,4 @@
"name": "兽医学",
"category": "STEM"
}
}

View File

@ -267,4 +267,4 @@
"name": "世界宗教", "name": "世界宗教",
"category": "Humanities" "category": "Humanities"
} }
} }

View File

@ -227,4 +227,4 @@
"name": "world religions", "name": "world religions",
"category": "Humanities" "category": "Humanities"
} }
} }

View File

@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
df = pd.read_csv(filepath, header=None) df = pd.read_csv(filepath, header=None)
df.columns = ["question", "A", "B", "C", "D", "answer"] df.columns = ["question", "A", "B", "C", "D", "answer"]
for i, instance in enumerate(df.to_dict(orient="records")): yield from enumerate(df.to_dict(orient="records"))
yield i, instance

View File

@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
#### Supervised Fine-Tuning on Multiple Nodes #### Supervised Fine-Tuning on Multiple Nodes
```bash ```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
``` ```
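`NODE_RANK` (formerly `RANK`) is forwarded to torchrun's `--node_rank` by the CLI, as the cli.py hunk later in this diff shows, so the first command above is roughly equivalent to the hand-written launch below; `--nproc_per_node` defaults to the number of visible devices, and the value 8 here is only an assumption:

```bash
# Sketch of the torchrun command the CLI assembles for node 0.
torchrun --nnodes 2 --node_rank 0 --nproc_per_node 8 \
  --master_addr 192.168.0.1 --master_port 29500 \
  src/llamafactory/launcher.py examples/train_lora/llama3_lora_sft.yaml
```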
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding) #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)

View File

@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
#### 多机指令监督微调 #### 多机指令监督微调
```bash ```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
``` ```
#### 使用 DeepSpeed ZeRO-3 平均分配显存 #### 使用 DeepSpeed ZeRO-3 平均分配显存

View File

@ -25,4 +25,4 @@
"contiguous_gradients": true, "contiguous_gradients": true,
"round_robin_gradients": true "round_robin_gradients": true
} }
} }

View File

@ -25,4 +25,4 @@
"contiguous_gradients": true, "contiguous_gradients": true,
"round_robin_gradients": true "round_robin_gradients": true
} }
} }

View File

@ -29,4 +29,4 @@
"contiguous_gradients": true, "contiguous_gradients": true,
"round_robin_gradients": true "round_robin_gradients": true
} }
} }

View File

@ -27,4 +27,4 @@
"stage3_max_reuse_distance": 1e9, "stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true "stage3_gather_16bit_weights_on_model_save": true
} }
} }

View File

@ -35,4 +35,4 @@
"stage3_max_reuse_distance": 1e9, "stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true "stage3_gather_16bit_weights_on_model_save": true
} }
} }

View File

@ -1,9 +1,9 @@
transformers>=4.41.2,<=4.45.0 transformers>=4.41.2,<=4.46.1
datasets>=2.16.0,<=2.21.0 datasets>=2.16.0,<=3.0.2
accelerate>=0.30.1,<=0.34.2 accelerate>=0.34.0,<=1.0.1
peft>=0.11.1,<=0.12.0 peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6 trl>=0.8.6,<=0.9.6
gradio>=4.0.0 gradio>=4.0.0,<5.0.0
pandas>=2.0.0 pandas>=2.0.0
scipy scipy
einops einops
@ -19,3 +19,4 @@ fire
packaging packaging
pyyaml pyyaml
numpy<2.0.0 numpy<2.0.0
av

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Microsoft Corporation and the LlamaFactory team. # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
# #
# This code is inspired by the Microsoft's DeepSpeed library. # This code is inspired by the Microsoft's DeepSpeed library.

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 imoneoi and the LlamaFactory team. # Copyright 2024 imoneoi and the LlamaFactory team.
# #
# This code is inspired by the imoneoi's OpenChat library. # This code is inspired by the imoneoi's OpenChat library.
@ -74,7 +73,7 @@ def calculate_lr(
elif stage == "sft": elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX) data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
else: else:
raise NotImplementedError("Stage does not supported: {}.".format(stage)) raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0 valid_tokens, total_tokens = 0, 0

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
elif "4090" in device_name: elif "4090" in device_name:
return 98 * 1e12 * world_size return 98 * 1e12 * world_size
else: else:
raise NotImplementedError("Device not supported: {}.".format(device_name)) raise NotImplementedError(f"Device not supported: {device_name}.")
def calculate_mfu( def calculate_mfu(
@ -140,10 +139,10 @@ def calculate_mfu(
"bf16": True, "bf16": True,
} }
if deepspeed_stage in [2, 3]: if deepspeed_stage in [2, 3]:
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage) args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args) run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f: with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f) result = json.load(f)
if dist.is_initialized(): if dist.is_initialized():
@ -157,7 +156,7 @@ def calculate_mfu(
* compute_model_flops(model_name_or_path, total_batch_size, seq_length) * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size) / compute_device_flops(world_size)
) )
print("MFU: {:.2f}%".format(mfu_value * 100)) print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__": if __name__ == "__main__":
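For reference, the ratio printed here is the usual Model FLOPs Utilization. Assuming the leading factor (not visible in this hunk) is the train steps-per-second value read from all_results.json above, the computation amounts to:

```latex
\mathrm{MFU} \;=\; \frac{\text{train steps/s} \times \text{model FLOPs per step}}
                        {\texttt{world\_size} \times \text{peak FLOPs/s per device}}
```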

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -100,7 +99,7 @@ def calculate_ppl(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
) )
else: else:
raise NotImplementedError("Stage does not supported: {}.".format(stage)) raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none") criterion = torch.nn.CrossEntropyLoss(reduction="none")
@ -125,8 +124,8 @@ def calculate_ppl(
with open(save_name, "w", encoding="utf-8") as f: with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2) json.dump(perplexities, f, indent=2)
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities))) print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
print("Perplexities have been saved at {}.".format(save_name)) print(f"Perplexities have been saved at {save_name}.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -61,7 +60,7 @@ def length_cdf(
for length, count in length_tuples: for length, count in length_tuples:
count_accu += count count_accu += count
prob_accu += count / total_num * 100 prob_accu += count / total_num * 100
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval)) print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Tencent Inc. and the LlamaFactory team. # Copyright 2024 Tencent Inc. and the LlamaFactory team.
# #
# This code is inspired by the Tencent's LLaMA-Pro library. # This code is inspired by the Tencent's LLaMA-Pro library.
@ -40,7 +39,7 @@ if TYPE_CHECKING:
def change_name(name: str, old_index: int, new_index: int) -> str: def change_name(name: str, old_index: int, new_index: int) -> str:
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index)) return name.replace(f".{old_index:d}.", f".{new_index:d}.")
def block_expansion( def block_expansion(
@ -76,27 +75,27 @@ def block_expansion(
state_dict = model.state_dict() state_dict = model.state_dict()
if num_layers % num_expand != 0: if num_layers % num_expand != 0:
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand)) raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
split = num_layers // num_expand split = num_layers // num_expand
layer_cnt = 0 layer_cnt = 0
output_state_dict = OrderedDict() output_state_dict = OrderedDict()
for i in range(num_layers): for i in range(num_layers):
for key, value in state_dict.items(): for key, value in state_dict.items():
if ".{:d}.".format(i) in key: if f".{i:d}." in key:
output_state_dict[change_name(key, i, layer_cnt)] = value output_state_dict[change_name(key, i, layer_cnt)] = value
print("Add layer {} copied from layer {}".format(layer_cnt, i)) print(f"Add layer {layer_cnt} copied from layer {i}")
layer_cnt += 1 layer_cnt += 1
if (i + 1) % split == 0: if (i + 1) % split == 0:
for key, value in state_dict.items(): for key, value in state_dict.items():
if ".{:d}.".format(i) in key: if f".{i:d}." in key:
if "down_proj" in key or "o_proj" in key: if "down_proj" in key or "o_proj" in key:
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value) output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
else: else:
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value) output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
print("Add layer {} expanded from layer {}".format(layer_cnt, i)) print(f"Add layer {layer_cnt} expanded from layer {i}")
layer_cnt += 1 layer_cnt += 1
for key, value in state_dict.items(): for key, value in state_dict.items():
@ -113,17 +112,17 @@ def block_expansion(
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name))) print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("finetuning_type: freeze") print("finetuning_type: freeze")
print("freeze_trainable_layers: {}".format(num_expand)) print(f"freeze_trainable_layers: {num_expand}")
print("use_llama_pro: true") print("use_llama_pro: true")

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME))) print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
def save_config(input_dir: str, output_dir: str): def save_config(input_dir: str, output_dir: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
llama2_config_dict: Dict[str, Any] = json.load(f) llama2_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict["architectures"] = ["LlamaForCausalLM"] llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2) json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_baichuan2( def llamafy_baichuan2(

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
elif "lm_head" in key: elif "lm_head" in key:
llama2_state_dict[key] = value llama2_state_dict[key] = value
else: else:
raise KeyError("Unable to process key {}".format(key)) raise KeyError(f"Unable to process key {key}")
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name) shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
torch.save(shard, os.path.join(output_dir, shard_file)) torch.save(shard, os.path.join(output_dir, shard_file))
if index is None: if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name))) print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
else: else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True) json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
return str(torch_dtype).replace("torch.", "") return str(torch_dtype).replace("torch.", "")
def save_config(input_dir: str, output_dir: str, torch_dtype: str): def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f) qwen_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict: Dict[str, Any] = OrderedDict() llama2_config_dict: Dict[str, Any] = OrderedDict()
@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2) json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
def llamafy_qwen( def llamafy_qwen(

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is based on the HuggingFace's PEFT library. # This code is based on the HuggingFace's PEFT library.
@ -70,19 +69,19 @@ def quantize_loftq(
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir)) setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors) peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir)) print(f"Adapter weights saved in {loftq_dir}")
# Save base model # Save base model
base_model: "PreTrainedModel" = peft_model.unload() base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("adapter_name_or_path: {}".format(loftq_dir)) print(f"adapter_name_or_path: {loftq_dir}")
print("finetuning_type: lora") print("finetuning_type: lora")
print("quantization_bit: {}".format(loftq_bits)) print(f"quantization_bit: {loftq_bits}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is based on the HuggingFace's PEFT library. # This code is based on the HuggingFace's PEFT library.
@ -54,7 +53,7 @@ def quantize_pissa(
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout, lora_dropout=lora_dropout,
target_modules=lora_target, target_modules=lora_target,
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter), init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
) )
# Init PiSSA model # Init PiSSA model
@ -65,17 +64,17 @@ def quantize_pissa(
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir)) setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors) peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir)) print(f"Adapter weights saved in {pissa_dir}")
# Save base model # Save base model
base_model: "PreTrainedModel" = peft_model.unload() base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir)) print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:") print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir)) print(f"model_name_or_path: {output_dir}")
print("adapter_name_or_path: {}".format(pissa_dir)) print(f"adapter_name_or_path: {pissa_dir}")
print("finetuning_type: lora") print("finetuning_type: lora")
print("pissa_init: false") print("pissa_init: false")
print("pissa_convert: true") print("pissa_convert: true")

scripts/test_image.py Normal file
View File

@ -0,0 +1,65 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from openai import OpenAI
from transformers.utils.versions import require_version
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def main():
client = OpenAI(
api_key="{}".format(os.environ.get("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
messages = []
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": "Output the color and number of each box."},
{
"type": "image_url",
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
},
],
}
)
result = client.chat.completions.create(messages=messages, model="test")
messages.append(result.choices[0].message)
print("Round 1:", result.choices[0].message.content)
# The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": "What kind of flower is this?"},
{
"type": "image_url",
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
},
],
}
)
result = client.chat.completions.create(messages=messages, model="test")
messages.append(result.choices[0].message)
print("Round 2:", result.choices[0].message.content)
# The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
if __name__ == "__main__":
main()
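A hedged sketch of running this new script end to end: serve a vision-language model through the OpenAI-compatible API first, then point the script at it via the environment variables it reads. The config path below is a placeholder, not a file from this commit:

```bash
# Placeholder YAML: any inference config for a VLM such as Qwen2-VL.
API_PORT=8000 llamafactory-cli api path/to/qwen2vl_inference.yaml
API_PORT=8000 python scripts/test_image.py
```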

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team. # Copyright 2024 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -20,7 +20,7 @@ from setuptools import find_packages, setup
def get_version() -> str: def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f: with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content) (version,) = re.findall(pattern, file_content)
@ -28,7 +28,7 @@ def get_version() -> str:
def get_requires() -> List[str]: def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f: with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines return lines
@ -54,13 +54,14 @@ extra_require = {
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"], "awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"], "aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.6.0"], "vllm": ["vllm>=0.4.3,<=0.6.3"],
"galore": ["galore-torch"], "galore": ["galore-torch"],
"badam": ["badam>=1.2.1"], "badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"], "adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"], "qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"], "modelscope": ["modelscope"],
"dev": ["ruff", "pytest"], "openmind": ["openmind"],
"dev": ["pre-commit", "ruff", "pytest"],
} }
@ -71,7 +72,7 @@ def main():
author="hiyouga", author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn", author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework", description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"], keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License", license="Apache 2.0 License",

View File

@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
def main(): def main():
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0") api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000")) api_port = int(os.getenv("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port)) print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port) uvicorn.run(app, host=api_host, port=api_port)

View File

@ -20,17 +20,17 @@ Level:
Dependency graph: Dependency graph:
main: main:
transformers>=4.41.2,<=4.45.0 transformers>=4.41.2,<=4.46.1
datasets>=2.16.0,<=2.21.0 datasets>=2.16.0,<=3.0.2
accelerate>=0.30.1,<=0.34.2 accelerate>=0.34.0,<=1.0.1
peft>=0.11.1,<=0.12.0 peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6 trl>=0.8.6,<=0.9.6
attention: attention:
transformers>=4.42.4 (gemma+fa2) transformers>=4.42.4 (gemma+fa2)
longlora: longlora:
transformers>=4.41.2,<=4.45.0 transformers>=4.41.2,<=4.46.1
packing: packing:
transformers>=4.41.2,<=4.45.0 transformers>=4.41.2,<=4.46.1
Disable version checking: DISABLE_VERSION_CHECK=1 Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1 Enable VRAM recording: RECORD_VRAM=1
@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1 Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1 Use modelscope: USE_MODELSCOPE_HUB=1
Use openmind: USE_OPENMIND_HUB=1
""" """
from .extras.env import VERSION from .extras.env import VERSION
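Like the ModelScope switch, the new openMind hub support is driven purely by the environment; a one-line sketch (the YAML path is reused from the examples earlier in this diff):

```bash
USE_OPENMIND_HUB=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```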

View File

@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem
def create_app(chat_model: "ChatModel") -> "FastAPI": def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.environ.get("FASTAPI_ROOT_PATH", "") root_path = os.getenv("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path) app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
api_key = os.environ.get("API_KEY", None) api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)], dependencies=[Depends(verify_api_key)],
) )
async def list_models(): async def list_models():
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo")) model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card]) return ModelList(data=[model_card])
@app.post( @app.post(
@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
def run_api() -> None: def run_api() -> None:
chat_model = ChatModel() chat_model = ChatModel()
app = create_app(chat_model) app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0") api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000")) api_port = int(os.getenv("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port)) print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port) uvicorn.run(app, host=api_host, port=api_port)

View File

@ -21,7 +21,7 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole from ..data import Role as DataRole
from ..extras.logging import get_logger from ..extras import logging
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify from .common import dictify, jsonify
from .protocol import ( from .protocol import (
@ -57,7 +57,7 @@ if TYPE_CHECKING:
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__) logger = logging.get_logger(__name__)
ROLE_MAPPING = { ROLE_MAPPING = {
Role.USER: DataRole.USER.value, Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value, Role.ASSISTANT: DataRole.ASSISTANT.value,
@ -69,8 +69,8 @@ ROLE_MAPPING = {
def _process_request( def _process_request(
request: "ChatCompletionRequest", request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]: ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False))) logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@ -84,7 +84,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = [] input_messages = []
image = None images = []
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@ -111,7 +111,7 @@ def _process_request(
else: # web uri else: # web uri
image_stream = requests.get(image_url, stream=True).raw image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_stream).convert("RGB") images.append(Image.open(image_stream).convert("RGB"))
else: else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@ -124,7 +124,7 @@ def _process_request(
else: else:
tools = None tools = None
return input_messages, system, tools, image return input_messages, system, tools, images or None
def _create_stream_chat_completion_chunk( def _create_stream_chat_completion_chunk(
@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
async def create_chat_completion_response( async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse": ) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request) input_messages, system, tools, images = _process_request(request)
responses = await chat_model.achat( responses = await chat_model.achat(
input_messages, input_messages,
system, system,
tools, tools,
image, images,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@ -169,7 +169,7 @@ async def create_chat_completion_response(
tool_calls = [] tool_calls = []
for tool in result: for tool in result:
function = Function(name=tool[0], arguments=tool[1]) function = Function(name=tool[0], arguments=tool[1])
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)) tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL finish_reason = Finish.TOOL
@ -193,8 +193,8 @@ async def create_chat_completion_response(
async def create_stream_chat_completion_response( async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex) completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request) input_messages, system, tools, images = _process_request(request)
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
input_messages, input_messages,
system, system,
tools, tools,
image, images,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
async def create_score_evaluation_response( async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel" request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse": ) -> "ScoreEvaluationResponse":
score_id = "scoreval-{}".format(uuid.uuid4().hex) score_id = f"scoreval-{uuid.uuid4().hex}"
if len(request.messages) == 0: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
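Because `_process_request` now accumulates every `image_url` part into a list instead of keeping only the last image, a single user turn may reference several images. A curl sketch against the OpenAI-compatible endpoint; the host, port, and image URLs are placeholders:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "test",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "text", "text": "Compare these two images."},
        {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
        {"type": "image_url", "image_url": {"url": "https://example.com/b.png"}}
      ]
    }]
  }'
```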

View File

@ -66,8 +66,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
@ -81,8 +81,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r""" r"""

View File

@ -53,7 +53,7 @@ class ChatModel:
elif model_args.infer_backend == "vllm": elif model_args.infer_backend == "vllm":
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args) self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else: else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend)) raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
self._loop = asyncio.new_event_loop() self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
@ -64,15 +64,15 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
Gets a list of responses of the chat model. Gets a list of responses of the chat model.
""" """
task = asyncio.run_coroutine_threadsafe( task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
) )
return task.result() return task.result()
@ -81,28 +81,28 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
Asynchronously gets a list of responses of the chat model. Asynchronously gets a list of responses of the chat model.
""" """
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs) return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
def stream_chat( def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
r""" r"""
Gets the response token-by-token of the chat model. Gets the response token-by-token of the chat model.
""" """
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs) generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
while True: while True:
try: try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@ -115,14 +115,14 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r""" r"""
Asynchronously gets the response token-by-token of the chat model. Asynchronously gets the response token-by-token of the chat model.
""" """
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs): async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
yield new_token yield new_token
def get_scores( def get_scores(

View File

@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override from typing_extensions import override
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class HuggingfaceEngine(BaseEngine): class HuggingfaceEngine(BaseEngine):
@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
try: try:
asyncio.get_event_loop() asyncio.get_event_loop()
except RuntimeError: except RuntimeError:
logger.warning("There is no current event loop, creating a new one.") logger.warning_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1"))) self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@staticmethod @staticmethod
def _process_args( def _process_args(
@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]} mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None: if images is not None:
mm_input_dict.update({"images": [image], "imglens": [1]}) mm_input_dict.update({"images": images, "imglens": [len(images)]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]: if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if video is not None: if videos is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]}) mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]: if not any(VIDEO_PLACEHOLDER not in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
messages = template.mm_plugin.process_messages( messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor messages, mm_input_dict["images"], mm_input_dict["videos"], processor
@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None: if stop is not None:
logger.warning("Stop parameter is not supported by the huggingface engine yet.") logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
generating_args = generating_args.copy() generating_args = generating_args.copy()
generating_args.update( generating_args.update(
@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine):
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor) mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
for key, value in mm_inputs.items(): for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value) if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
gen_kwargs[key] = value.to(model.device) gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length return gen_kwargs, prompt_length
@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]: ) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args( gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
) )
generate_output = model.generate(**gen_kwargs) generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]: ) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args( gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
) )
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
if not self.can_generate: if not self.can_generate:
@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
messages, messages,
system, system,
tools, tools,
image, images,
video, videos,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
if not self.can_generate: if not self.can_generate:
@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
messages, messages,
system, system,
tools, tools,
image, images,
video, videos,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
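Note that the semaphore created earlier from `MAX_CONCURRENT` (default 1) gates both paths above, so the HuggingFace engine answers one request at a time unless that variable is raised; a sketch, with the config path again a placeholder:

```bash
MAX_CONCURRENT=4 API_PORT=8000 llamafactory-cli api path/to/inference_config.yaml
```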

View File

@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
from typing_extensions import override from typing_extensions import override
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer from ..model import load_config, load_tokenizer
@ -43,7 +43,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class VllmEngine(BaseEngine): class VllmEngine(BaseEngine):
@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.") logger.info_rank0("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncIterator["RequestOutput"]: ) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex) request_id = f"chatcmpl-{uuid.uuid4().hex}"
if image is not None: if images is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]: if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"] system = system or self.generating_args["default_system"]
@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
skip_special_tokens=True, skip_special_tokens=True,
) )
if image is not None: # add image features if images is not None: # add image features
if not isinstance(image, (str, ImageObject)): image_data = []
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image))) for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str): if isinstance(image, str):
image = Image.open(image).convert("RGB") image = Image.open(image).convert("RGB")
multi_modal_data = {"image": image} image_data.append(image)
multi_modal_data = {"image": image_data}
else: else:
multi_modal_data = None multi_modal_data = None
@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
final_output = None final_output = None
generator = await self._generate(messages, system, tools, image, video, **input_kwargs) generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for request_output in generator: async for request_output in generator:
final_output = request_output final_output = request_output
@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
generated_text = "" generated_text = ""
generator = await self._generate(messages, system, tools, image, video, **input_kwargs) generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for result in generator: async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :] delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text generated_text = result.outputs[0].text

View File

@ -22,8 +22,8 @@ from . import launcher
from .api.app import run_api from .api.app import run_api
from .chat.chat_model import run_chat from .chat.chat_model import run_chat
from .eval.evaluator import run_eval from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui from .webui.interface import run_web_demo, run_web_ui
@ -47,7 +47,7 @@ USAGE = (
WELCOME = ( WELCOME = (
"-" * 58 "-" * 58
+ "\n" + "\n"
+ "| Welcome to LLaMA Factory, version {}".format(VERSION) + f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION)) + " " * (21 - len(VERSION))
+ "|\n|" + "|\n|"
+ " " * 56 + " " * 56
@ -56,7 +56,7 @@ WELCOME = (
+ "-" * 58 + "-" * 58
) )
logger = get_logger(__name__) logger = logging.get_logger(__name__)
@unique @unique
@ -86,19 +86,19 @@ def main():
elif command == Command.EXPORT: elif command == Command.EXPORT:
export_model() export_model()
elif command == Command.TRAIN: elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"] force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1: if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port)) logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
process = subprocess.run( process = subprocess.run(
( (
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}" "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format( ).format(
nnodes=os.environ.get("NNODES", "1"), nnodes=os.getenv("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"), node_rank=os.getenv("NODE_RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())), nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
file_name=launcher.__file__, file_name=launcher.__file__,
@ -118,4 +118,4 @@ def main():
elif command == Command.HELP: elif command == Command.HELP:
print(USAGE) print(USAGE)
else: else:
raise NotImplementedError("Unknown command: {}.".format(command)) raise NotImplementedError(f"Unknown command: {command}.")
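For reference, a hedged sketch of how the environment variables read above (with `NODE_RANK` replacing the old `RANK`) could drive a two-node launch; the host address, port, worker counts, and config path are illustrative only.

```python
import os
import subprocess

# Illustrative values; NODE_RANK would be "1" on the second machine.
env = dict(
    os.environ,
    FORCE_TORCHRUN="1",
    NNODES="2",
    NODE_RANK="0",
    NPROC_PER_NODE="8",
    MASTER_ADDR="192.168.0.1",
    MASTER_PORT="29500",
)
subprocess.run(
    ["llamafactory-cli", "train", "examples/train_lora/llama3_lora_sft.yaml"],
    env=env,
    check=True,
)
```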

View File

@ -16,7 +16,7 @@ import os
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras.logging import get_logger from ..extras import logging
from .data_utils import Role from .data_utils import Role
@ -29,45 +29,51 @@ if TYPE_CHECKING:
from .parser import DatasetAttr from .parser import DatasetAttr
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _convert_images( def _convert_images(
images: Sequence["ImageInput"], images: Union["ImageInput", Sequence["ImageInput"]],
dataset_attr: "DatasetAttr", dataset_attr: "DatasetAttr",
data_args: "DataArguments", data_args: "DataArguments",
) -> Optional[List["ImageInput"]]: ) -> Optional[List["ImageInput"]]:
r""" r"""
Optionally concatenates image path to dataset dir when loading from local disk. Optionally concatenates image path to dataset dir when loading from local disk.
""" """
if len(images) == 0: if not isinstance(images, list):
images = [images]
elif len(images) == 0:
return None return None
else:
images = images[:]
images = images[:]
if dataset_attr.load_from in ["script", "file"]: if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)): for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])): if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
images[i] = os.path.join(data_args.dataset_dir, images[i]) images[i] = os.path.join(data_args.image_dir, images[i])
return images return images
def _convert_videos( def _convert_videos(
videos: Sequence["VideoInput"], videos: Union["VideoInput", Sequence["VideoInput"]],
dataset_attr: "DatasetAttr", dataset_attr: "DatasetAttr",
data_args: "DataArguments", data_args: "DataArguments",
) -> Optional[List["VideoInput"]]: ) -> Optional[List["VideoInput"]]:
r""" r"""
Optionally concatenates video path to dataset dir when loading from local disk. Optionally concatenates video path to dataset dir when loading from local disk.
""" """
if len(videos) == 0: if not isinstance(videos, list):
videos = [videos]
elif len(videos) == 0:
return None return None
else:
videos = videos[:]
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]: if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)): for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])): if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
videos[i] = os.path.join(data_args.dataset_dir, videos[i]) videos[i] = os.path.join(data_args.image_dir, videos[i])
return videos return videos
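A small sketch of the single-or-list behavior of the two converters above; the `dataset_attr` and `data_args` objects and the directory value are assumed for illustration.

```python
# Hypothetical call, assuming the dataset was loaded from a local file and
# data_args.image_dir == "data/images" (both objects constructed elsewhere):
paths = _convert_images("cat.jpg", dataset_attr=dataset_attr, data_args=data_args)
# -> ["data/images/cat.jpg"] if that file exists, otherwise ["cat.jpg"]
# _convert_images([], dataset_attr=dataset_attr, data_args=data_args) still returns None.
```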
@ -161,7 +167,7 @@ def convert_sharegpt(
broken_data = False broken_data = False
for turn_idx, message in enumerate(messages): for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages)) logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True broken_data = True
aligned_messages.append( aligned_messages.append(
@ -171,7 +177,7 @@ def convert_sharegpt(
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0 dataset_attr.ranking and len(aligned_messages) % 2 == 0
): ):
logger.warning("Invalid message count in {}.".format(messages)) logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
@ -192,7 +198,7 @@ def convert_sharegpt(
chosen[dataset_attr.role_tag] not in accept_tags[-1] chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1] or rejected[dataset_attr.role_tag] not in accept_tags[-1]
): ):
logger.warning("Invalid role tag in {}.".format([chosen, rejected])) logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True broken_data = True
prompt = aligned_messages prompt = aligned_messages
@ -205,7 +211,7 @@ def convert_sharegpt(
response = aligned_messages[-1:] response = aligned_messages[-1:]
if broken_data: if broken_data:
logger.warning("Skipping this abnormal example.") logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], [] prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)

View File

@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: Dict[str, "torch.Tensor"] = super().__call__(features) features: Dict[str, "torch.Tensor"] = super().__call__(features)
features.update(mm_inputs) features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
return features return features
@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
for key in ("chosen", "rejected"): for key in ("chosen", "rejected"):
for feature in features: for feature in features:
target_feature = { target_feature = {
"input_ids": feature["{}_input_ids".format(key)], "input_ids": feature[f"{key}_input_ids"],
"attention_mask": feature["{}_attention_mask".format(key)], "attention_mask": feature[f"{key}_attention_mask"],
"labels": feature["{}_labels".format(key)], "labels": feature[f"{key}_labels"],
"images": feature["images"], "images": feature["images"],
"videos": feature["videos"], "videos": feature["videos"],
} }

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
from datasets import DatasetDict, concatenate_datasets, interleave_datasets from datasets import DatasetDict, concatenate_datasets, interleave_datasets
from ..extras.logging import get_logger from ..extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments from ..hparams import DataArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0] return all_datasets[0]
elif data_args.mix_strategy == "concat": elif data_args.mix_strategy == "concat":
if data_args.streaming: if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.") logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets) return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"): elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming: if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets( return interleave_datasets(
datasets=all_datasets, datasets=all_datasets,
@ -70,7 +70,7 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
) )
else: else:
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy)) raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset( def split_dataset(

View File

@ -83,14 +83,14 @@ class StringFormatter(Formatter):
if isinstance(slot, str): if isinstance(slot, str):
for name, value in kwargs.items(): for name, value in kwargs.items():
if not isinstance(value, str): if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value)) raise RuntimeError(f"Expected a string, got {value}")
slot = slot.replace("{{" + name + "}}", value, 1) slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot) elements.append(slot)
elif isinstance(slot, (dict, set)): elif isinstance(slot, (dict, set)):
elements.append(slot) elements.append(slot)
else: else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements return elements
@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError: except json.JSONDecodeError:
functions = [] raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = [] elements = []
for name, arguments in functions: for name, arguments in functions:
@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
elif isinstance(slot, (dict, set)): elif isinstance(slot, (dict, set)):
elements.append(slot) elements.append(slot)
else: else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements return elements
@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content) tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""] return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError: except json.JSONDecodeError:
return [""] raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
@override @override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, List["FunctionCall"]]:

View File

@ -20,8 +20,8 @@ import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data from ..extras.misc import has_tokenized_data
from .aligner import align_dataset from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset from .data_utils import merge_dataset, split_dataset
@ -39,7 +39,7 @@ if TYPE_CHECKING:
from .template import Template from .template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _load_single_dataset( def _load_single_dataset(
@ -51,9 +51,9 @@ def _load_single_dataset(
r""" r"""
Loads a single dataset and aligns it to the standard format. Loads a single dataset and aligns it to the standard format.
""" """
logger.info("Loading dataset {}...".format(dataset_attr)) logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset data_name = dataset_attr.subset
data_dir = dataset_attr.folder data_dir = dataset_attr.folder
@ -69,25 +69,24 @@ def _load_single_dataset(
if os.path.isdir(local_path): # is directory if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path): for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name)) data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file elif os.path.isfile(local_path): # is file
data_files.append(local_path) data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else: else:
raise ValueError("File {} not found.".format(local_path)) raise ValueError(f"File {local_path} not found.")
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
if data_path is None: if data_path is None:
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys()))) raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
raise ValueError("File types should be identical.")
else: else:
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from)) raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub": if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import MsDataset from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load( dataset = MsDataset.load(
@ -98,10 +97,27 @@ def _load_single_dataset(
split=dataset_attr.split, split=dataset_attr.split,
cache_dir=cache_dir, cache_dir=cache_dir,
token=model_args.ms_hub_token, token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), use_streaming=data_args.streaming,
) )
if isinstance(dataset, MsDataset): if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset() dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=data_args.streaming,
)
else: else:
dataset = load_dataset( dataset = load_dataset(
path=data_path, path=data_path,
@ -111,13 +127,10 @@ def _load_single_dataset(
split=dataset_attr.split, split=dataset_attr.split,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token, token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")), streaming=data_args.streaming,
trust_remote_code=True, trust_remote_code=True,
) )
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
@ -128,7 +141,7 @@ def _load_single_dataset(
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes) dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset)) max_samples = min(data_args.max_samples, len(dataset))
@ -224,9 +237,9 @@ def get_dataset(
# Load tokenized dataset # Load tokenized dataset
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.") logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {} dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict: if "train" in dataset_dict:
@ -277,8 +290,8 @@ def get_dataset(
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if training_args.should_save: if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path) dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0) sys.exit(0)

View File

@ -4,6 +4,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np import numpy as np
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@ -110,7 +111,7 @@ class BasePlugin:
image = Image.open(image["path"]) image = Image.open(image["path"])
if not isinstance(image, ImageObject): if not isinstance(image, ImageObject):
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image))) raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
results.append(self._preprocess_image(image, **kwargs)) results.append(self._preprocess_image(image, **kwargs))
@ -157,6 +158,7 @@ class BasePlugin:
It holds num_patches == torch.prod(image_grid_thw) It holds num_patches == torch.prod(image_grid_thw)
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
input_dict = {"images": None} # default key input_dict = {"images": None} # default key
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
@ -174,10 +176,16 @@ class BasePlugin:
) )
input_dict["videos"] = videos input_dict["videos"] = videos
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None: mm_inputs = {}
return image_processor(**input_dict, return_tensors="pt") if image_processor != video_processor:
else: if input_dict.get("images") is not None:
return {} mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
if input_dict.get("videos") is not None:
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
return mm_inputs
def process_messages( def process_messages(
self, self,
@ -218,6 +226,14 @@ class BasePlugin:
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r""" r"""
Builds batched multimodal inputs for VLMs. Builds batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
seqlens: number of tokens in each sample, shape (batch_size,)
processor: a processor for pre-processing images and videos
""" """
self._validate_input(images, videos) self._validate_input(images, videos)
return {} return {}
@ -245,7 +261,123 @@ class LlavaPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen) message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "image_sizes" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while self.image_token in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
res = self._get_mm_inputs(images, videos, processor)
return res
class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while self.image_token in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
for message in messages:
content = message["content"]
while self.video_token in content:
num_video_tokens += 1
content = content.replace(self.video_token, "{{video}}", 1)
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages return messages
@ -284,7 +416,7 @@ class PaliGemmaPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", "") message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages return messages
@ -324,6 +456,68 @@ class PaliGemmaPlugin(BasePlugin):
return mm_inputs return mm_inputs
class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = mm_inputs.get("image_sizes", None)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
mm_inputs.pop("image_sizes", None)
return mm_inputs
class Qwen2vlPlugin(BasePlugin): class Qwen2vlPlugin(BasePlugin):
@override @override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@ -369,7 +563,7 @@ class Qwen2vlPlugin(BasePlugin):
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw): if num_image_tokens >= len(image_grid_thw):
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER)) raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
content = content.replace( content = content.replace(
IMAGE_PLACEHOLDER, IMAGE_PLACEHOLDER,
@ -382,7 +576,7 @@ class Qwen2vlPlugin(BasePlugin):
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw): if num_video_tokens >= len(video_grid_thw):
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER)) raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
content = content.replace( content = content.replace(
VIDEO_PLACEHOLDER, VIDEO_PLACEHOLDER,
@ -396,10 +590,73 @@ class Qwen2vlPlugin(BasePlugin):
message["content"] = content message["content"] = content
if len(images) != num_image_tokens: if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
if len(videos) != num_video_tokens: if len(videos) != num_video_tokens:
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER)) raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_frames = 0
exist_images = "pixel_values_images" in mm_inputs
exist_videos = "pixel_values_videos" in mm_inputs
if exist_videos or exist_images:
if exist_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if exist_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
for message in messages:
content = message["content"]
while self.image_token in content:
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}", 1)
while self.video_token in content:
num_video_tokens += 1
content = content.replace(self.video_token, "{{video}}", 1)
content = content.replace("{{image}}", self.image_token * image_seqlen)
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
return messages return messages
@ -420,8 +677,12 @@ class Qwen2vlPlugin(BasePlugin):
PLUGINS = { PLUGINS = {
"base": BasePlugin, "base": BasePlugin,
"llava": LlavaPlugin, "llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"paligemma": PaliGemmaPlugin, "paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin, "qwen2_vl": Qwen2vlPlugin,
"video_llava": VideoLlavaPlugin,
} }
@ -432,6 +693,6 @@ def get_mm_plugin(
) -> "BasePlugin": ) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None) plugin_class = PLUGINS.get(name, None)
if plugin_class is None: if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name)) raise ValueError(f"Multimodal plugin `{name}` not found.")
return plugin_class(image_token, video_token) return plugin_class(image_token, video_token)
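A minimal sketch of fetching one of the newly registered plugins; the token string follows the `llava_next` template registration later in this commit, and unknown names raise the ValueError above.

```python
# Assumes the "llava_next" entry added to PLUGINS above.
plugin = get_mm_plugin(name="llava_next", image_token="<image>")
print(type(plugin).__name__)  # -> "LlavaNextPlugin"
```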

View File

@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
from transformers.utils import cached_file from transformers.utils import cached_file
from ..extras.constants import DATA_CONFIG from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope from ..extras.misc import use_modelscope, use_openmind
@dataclass @dataclass
@ -30,7 +30,7 @@ class DatasetAttr:
""" """
# basic configs # basic configs
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca" formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False ranking: bool = False
@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
config_path = os.path.join(dataset_dir, DATA_CONFIG) config_path = os.path.join(dataset_dir, DATA_CONFIG)
try: try:
with open(config_path, "r") as f: with open(config_path) as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except Exception as err: except Exception as err:
if len(dataset_names) != 0: if len(dataset_names) != 0:
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
dataset_info = None dataset_info = None
dataset_list: List["DatasetAttr"] = [] dataset_list: List["DatasetAttr"] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE if dataset_info is None: # dataset_dir is ONLINE
load_from = "ms_hub" if use_modelscope() else "hf_hub" if use_modelscope():
load_from = "ms_hub"
elif use_openmind():
load_from = "om_hub"
else:
load_from = "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name) dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr) dataset_list.append(dataset_attr)
continue continue
if name not in dataset_info: if name not in dataset_info:
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
has_hf_url = "hf_hub_url" in dataset_info[name] has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name] has_ms_url = "ms_hub_url" in dataset_info[name]
has_om_url = "om_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url: if has_hf_url or has_ms_url or has_om_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url): if has_ms_url and (use_modelscope() or not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
elif has_om_url and (use_openmind() or not has_hf_url):
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
else: else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:

View File

@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_feedback_example( def _encode_feedback_example(
@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0: if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.") logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs return model_inputs

View File

@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_pairwise_example( def _encode_pairwise_example(
@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"])) print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))) print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"])) print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False))) print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"])) print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False))) print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"])) print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False))) print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")

View File

@ -15,8 +15,8 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen from .processor_utils import greedy_knapsack, infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_supervised_example( def _encode_supervised_example(
@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
length2indexes = defaultdict(list) length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
) )
length = len(input_ids) length = len(input_ids)
if length > data_args.cutoff_len: if length > data_args.cutoff_len:
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else: else:
lengths.append(length) lengths.append(length)
length2indexes[length].append(valid_num) length2indexes[length].append(valid_num)
@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
print("input_ids:\n{}".format(example["input_ids"])) print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"])) print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))) print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")

View File

@ -15,7 +15,7 @@
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger from ...extras import logging
from ..data_utils import Role from ..data_utils import Role
from .processor_utils import infer_seqlen from .processor_utils import infer_seqlen
@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template from ..template import Template
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _encode_unsupervised_example( def _encode_unsupervised_example(
@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
model_inputs = defaultdict(list) model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])): for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1: if len(examples["_prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue continue
input_ids, labels = _encode_unsupervised_example( input_ids, labels = _encode_unsupervised_example(

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from typing_extensions import override from typing_extensions import override
from ..extras.logging import get_logger from ..extras import logging
from .data_utils import Role from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin from .mm_plugin import get_mm_plugin
@ -32,7 +32,7 @@ if TYPE_CHECKING:
from .mm_plugin import BasePlugin from .mm_plugin import BasePlugin
logger = get_logger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass
@ -49,6 +49,7 @@ class Template:
stop_words: List[str] stop_words: List[str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool
mm_plugin: "BasePlugin" mm_plugin: "BasePlugin"
def encode_oneturn( def encode_oneturn(
@ -146,7 +147,7 @@ class Template:
elif "eos_token" in elem and tokenizer.eos_token_id is not None: elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id] token_ids += [tokenizer.eos_token_id]
else: else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids return token_ids
@ -214,6 +215,7 @@ def _register_template(
stop_words: Sequence[str] = [], stop_words: Sequence[str] = [],
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None: ) -> None:
r""" r"""
@ -263,6 +265,7 @@ def _register_template(
stop_words=stop_words, stop_words=stop_words,
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
mm_plugin=mm_plugin, mm_plugin=mm_plugin,
) )
@ -272,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added: if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token)) logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
else: else:
logger.info("Replace eos token: {}".format(tokenizer.eos_token)) logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0: if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.") logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str: def _jinja_escape(content: str) -> str:
@ -353,24 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
r""" r"""
Gets chat template and fixes the tokenizer. Gets chat template and fixes the tokenizer.
""" """
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
)
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
if data_args.template is None: if data_args.template is None:
template = TEMPLATES["empty"] # placeholder template = TEMPLATES["empty"] # placeholder
else: else:
template = TEMPLATES.get(data_args.template, None) template = TEMPLATES.get(data_args.template, None)
if template is None: if template is None:
raise ValueError("Template {} does not exist.".format(data_args.template)) raise ValueError(f"Template {data_args.template} does not exist.")
if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None: if data_args.tool_format is not None:
logger.info("Using tool format: {}.".format(data_args.tool_format)) logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}] eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format) template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format) template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@ -388,20 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token)) logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words: if stop_words:
num_added_tokens = tokenizer.add_special_tokens( num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
) )
logger.info("Add {} to stop words.".format(",".join(stop_words))) logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0: if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.") logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
try: if tokenizer.chat_template is None or template.replace_jinja_template:
tokenizer.chat_template = _get_jinja_template(template, tokenizer) try:
except ValueError: tokenizer.chat_template = _get_jinja_template(template, tokenizer)
logger.info("Cannot add this chat template to tokenizer.") except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
return template return template
@ -640,6 +641,14 @@ _register_template(
) )
_register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template( _register_template(
name="falcon", name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@ -664,6 +673,7 @@ _register_template(
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]), format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True, efficient_eos=True,
replace_jinja_template=False,
) )
@ -681,6 +691,14 @@ _register_template(
) )
_register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]),
efficient_eos=True,
)
_register_template( _register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@ -740,6 +758,7 @@ _register_template(
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"], stop_words=["<|eot_id|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
) )
@ -754,6 +773,107 @@ _register_template(
) )
_register_template(
name="llava_next",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
_register_template(
name="llava_next_video",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template(
name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
_register_template( _register_template(
name="mistral", name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@ -831,6 +951,14 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
_register_template( _register_template(
name="qwen", name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -840,6 +968,7 @@ _register_template(
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
) )
@ -852,6 +981,7 @@ _register_template(
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
) )
@ -907,6 +1037,17 @@ _register_template(
) )
_register_template(
name="video_llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
)
_register_template( _register_template(
name="xuanyuan", name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
View File
@ -177,6 +177,6 @@ TOOLS = {
def get_tool_utils(name: str) -> "ToolUtils": def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None) tool_utils = TOOLS.get(name, None)
if tool_utils is None: if tool_utils is None:
raise ValueError("Tool utils `{}` not found.".format(name)) raise ValueError(f"Tool utils `{name}` not found.")
return tool_utils return tool_utils
View File
@ -87,7 +87,7 @@ class Evaluator:
token=self.model_args.hf_hub_token, token=self.model_args.hf_hub_token,
) )
with open(mapping, "r", encoding="utf-8") as f: with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f) categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS} category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@ -139,7 +139,7 @@ class Evaluator:
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None: def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join( score_info = "\n".join(
[ [
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
for category_name, category_correct in category_corrects.items() for category_name, category_correct in category_corrects.items()
if len(category_correct) if len(category_correct)
] ]
View File
@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
def get_eval_template(name: str) -> "EvalTemplate": def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None) eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name) assert eval_template is not None, f"Template {name} does not exist."
return eval_template return eval_template
View File
@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Dict, Optional
@ -47,7 +48,7 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100 IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = "<image>" IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
LAYERNORM_NAMES = {"norm", "ln"} LAYERNORM_NAMES = {"norm", "ln"}
@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = "<video>" VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_WEIGHTS_NAME = "value_head.bin"
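Because the placeholders are now read from the environment, they can be changed without editing the code, which helps when a dataset already uses a different image/video marker. A small illustration (the override values below are examples only, and the import path assumes the usual `llamafactory.extras.constants` layout):

```python
import os

# Must be set before the constants module is imported,
# e.g. exported in the shell or set at the top of a launcher script.
os.environ["IMAGE_PLACEHOLDER"] = "<img>"   # example value, not a project default
os.environ["VIDEO_PLACEHOLDER"] = "<vid>"   # example value, not a project default

from llamafactory.extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER

print(IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER)  # -> <img> <vid>
```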
@ -107,6 +108,7 @@ VISION_MODELS = set()
class DownloadSource(str, Enum): class DownloadSource(str, Enum):
DEFAULT = "hf" DEFAULT = "hf"
MODELSCOPE = "ms" MODELSCOPE = "ms"
OPENMIND = "om"
def register_model_group( def register_model_group(
@ -114,17 +116,12 @@ def register_model_group(
template: Optional[str] = None, template: Optional[str] = None,
vision: bool = False, vision: bool = False,
) -> None: ) -> None:
prefix = None
for name, path in models.items(): for name, path in models.items():
if prefix is None:
prefix = name.split("-")[0]
else:
assert prefix == name.split("-")[0], "prefix should be identical."
SUPPORTED_MODELS[name] = path SUPPORTED_MODELS[name] = path
if template is not None: if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
DEFAULT_TEMPLATE[prefix] = template DEFAULT_TEMPLATE[name] = template
if vision: if vision:
VISION_MODELS.add(prefix) VISION_MODELS.add(name)
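This hunk drops the old "shared prefix" assumption: every model name is registered individually, and the default chat template is recorded only for names containing `-Chat` or `-Instruct`. A sketch with a made-up group (the names below are hypothetical, for illustration only):

```python
# Hypothetical model group, only to illustrate the new registration behavior.
register_model_group(
    models={
        "Foo-8B": {DownloadSource.DEFAULT: "org/Foo-8B"},
        "Foo-8B-Instruct": {DownloadSource.DEFAULT: "org/Foo-8B-Instruct"},
    },
    template="foo",
)

# Afterwards:
#   SUPPORTED_MODELS contains both "Foo-8B" and "Foo-8B-Instruct"
#   DEFAULT_TEMPLATE["Foo-8B-Instruct"] == "foo"
#   "Foo-8B" (the base model) gets no default template
```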
register_model_group( register_model_group(
@ -168,14 +165,17 @@ register_model_group(
"Baichuan2-13B-Base": { "Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
}, },
"Baichuan2-7B-Chat": { "Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
}, },
"Baichuan2-13B-Chat": { "Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
}, },
}, },
template="baichuan2", template="baichuan2",
@ -274,27 +274,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"ChineseLLaMA2-1.3B": { "Chinese-Llama-2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
}, },
"ChineseLLaMA2-7B": { "Chinese-Llama-2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
}, },
"ChineseLLaMA2-13B": { "Chinese-Llama-2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
}, },
"ChineseLLaMA2-1.3B-Chat": { "Chinese-Alpaca-2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
}, },
"ChineseLLaMA2-7B-Chat": { "Chinese-Alpaca-2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
}, },
"ChineseLLaMA2-13B-Chat": { "Chinese-Alpaca-2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
}, },
@ -450,25 +450,25 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"DeepSeekCoder-6.7B-Base": { "DeepSeek-Coder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
}, },
"DeepSeekCoder-7B-Base": { "DeepSeek-Coder-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
}, },
"DeepSeekCoder-33B-Base": { "DeepSeek-Coder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
}, },
"DeepSeekCoder-6.7B-Instruct": { "DeepSeek-Coder-6.7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
}, },
"DeepSeekCoder-7B-Instruct": { "DeepSeek-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
}, },
"DeepSeekCoder-33B-Instruct": { "DeepSeek-Coder-33B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
}, },
@ -477,6 +477,16 @@ register_model_group(
) )
register_model_group(
models={
"EXAONE-3.0-7.8B-Instruct": {
DownloadSource.DEFAULT: "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
},
},
template="exaone",
)
register_model_group( register_model_group(
models={ models={
"Falcon-7B": { "Falcon-7B": {
@ -550,10 +560,12 @@ register_model_group(
"Gemma-2-2B-Instruct": { "Gemma-2-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it", DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
}, },
"Gemma-2-9B-Instruct": { "Gemma-2-9B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it", DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
}, },
"Gemma-2-27B-Instruct": { "Gemma-2-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it", DownloadSource.DEFAULT: "google/gemma-2-27b-it",
@ -573,6 +585,7 @@ register_model_group(
"GLM-4-9B-Chat": { "GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat", DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat", DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
}, },
"GLM-4-9B-1M-Chat": { "GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m", DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
@ -583,6 +596,33 @@ register_model_group(
) )
register_model_group(
models={
"Index-1.9B-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
},
"Index-1.9B-Character-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
},
"Index-1.9B-Base": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
},
"Index-1.9B-Base-Pure": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
},
"Index-1.9B-Chat-32K": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
},
},
template="index",
)
register_model_group( register_model_group(
models={ models={
"InternLM-7B": { "InternLM-7B": {
@ -624,16 +664,10 @@ register_model_group(
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b", DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
}, },
},
template="intern2",
)
register_model_group(
models={
"InternLM2.5-1.8B": { "InternLM2.5-1.8B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b", DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
}, },
"InternLM2.5-7B": { "InternLM2.5-7B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b", DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
@ -642,22 +676,27 @@ register_model_group(
"InternLM2.5-20B": { "InternLM2.5-20B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b", DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
}, },
"InternLM2.5-1.8B-Chat": { "InternLM2.5-1.8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
}, },
"InternLM2.5-7B-Chat": { "InternLM2.5-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
}, },
"InternLM2.5-7B-1M-Chat": { "InternLM2.5-7B-1M-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
}, },
"InternLM2.5-20B-Chat": { "InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
}, },
}, },
template="intern2", template="intern2",
@ -686,19 +725,19 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA-7B": { "Llama-7B": {
DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
}, },
"LLaMA-13B": { "Llama-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b", DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b", DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
}, },
"LLaMA-30B": { "Llama-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b", DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b", DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
}, },
"LLaMA-65B": { "Llama-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b", DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b", DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
}, },
@ -708,27 +747,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA2-7B": { "Llama-2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
}, },
"LLaMA2-13B": { "Llama-2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
}, },
"LLaMA2-70B": { "Llama-2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
}, },
"LLaMA2-7B-Chat": { "Llama-2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
}, },
"LLaMA2-13B-Chat": { "Llama-2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
}, },
"LLaMA2-70B-Chat": { "Llama-2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
}, },
@ -739,60 +778,78 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA3-8B": { "Llama-3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
}, },
"LLaMA3-70B": { "Llama-3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
}, },
"LLaMA3-8B-Instruct": { "Llama-3-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
}, },
"LLaMA3-70B-Instruct": { "Llama-3-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
}, },
"LLaMA3-8B-Chinese-Chat": { "Llama-3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat", DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
}, },
"LLaMA3-70B-Chinese-Chat": { "Llama-3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
}, },
}, "Llama-3.1-8B": {
template="llama3",
)
register_model_group(
models={
"LLaMA3.1-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
}, },
"LLaMA3.1-70B": { "Llama-3.1-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
}, },
"LLaMA3.1-405B": { "Llama-3.1-405B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
}, },
"LLaMA3.1-8B-Instruct": { "Llama-3.1-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
}, },
"LLaMA3.1-70B-Instruct": { "Llama-3.1-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
}, },
"LLaMA3.1-405B-Instruct": { "Llama-3.1-405B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
}, },
"Llama-3.1-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-8B-Chinese-Chat",
},
"Llama-3.1-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-70B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-70B-Chinese-Chat",
},
"Llama-3.2-1B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B",
},
"Llama-3.2-3B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B",
},
"Llama-3.2-1B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B-Instruct",
},
"Llama-3.2-3B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
},
}, },
template="llama3", template="llama3",
) )
@ -800,11 +857,13 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaVA1.5-7B-Chat": { "LLaVA-1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-7b-hf",
}, },
"LLaVA1.5-13B-Chat": { "LLaVA-1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-13b-hf",
}, },
}, },
template="llava", template="llava",
@ -812,6 +871,117 @@ register_model_group(
) )
register_model_group(
models={
"LLaVA-NeXT-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-7b-hf",
},
"LLaVA-NeXT-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-13b-hf",
},
},
template="llava_next",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Mistral-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-mistral-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-mistral-7b-hf",
},
},
template="llava_next_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Llama3-8B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llama3-llava-next-8b-hf",
DownloadSource.MODELSCOPE: "swift/llama3-llava-next-8b-hf",
},
},
template="llava_next_llama3",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-34b-hf",
DownloadSource.MODELSCOPE: "LLM-Research/llava-v1.6-34b-hf",
},
},
template="llava_next_yi",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-72B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-72b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-72b-hf",
},
"LLaVA-NeXT-110B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-110b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-110b-hf",
},
},
template="llava_next_qwen",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-hf",
},
"LLaVA-NeXT-Video-7B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-DPO-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-DPO-hf",
},
},
template="llava_next_video",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-32k-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-32K-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-32K-hf",
},
},
template="llava_next_video_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-34B-hf",
},
"LLaVA-NeXT-Video-34B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-DPO-hf",
},
},
template="llava_next_video_yi",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"MiniCPM-2B-SFT-Chat": { "MiniCPM-2B-SFT-Chat": {
@ -832,6 +1002,7 @@ register_model_group(
"MiniCPM3-4B-Chat": { "MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B", DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
}, },
}, },
template="cpm3", template="cpm3",
@ -1005,27 +1176,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Phi3-4B-4k-Instruct": { "Phi-3-4B-4k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
}, },
"Phi3-4B-128k-Instruct": { "Phi-3-4B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
}, },
"Phi3-7B-8k-Instruct": { "Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
}, },
"Phi3-7B-128k-Instruct": { "Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
}, },
"Phi3-14B-8k-Instruct": { "Phi-3-14B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
}, },
"Phi3-14B-128k-Instruct": { "Phi-3-14B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
}, },
@ -1034,6 +1205,18 @@ register_model_group(
) )
register_model_group(
models={
"Pixtral-12B-Chat": {
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
}
},
template="pixtral",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"Qwen-1.8B": { "Qwen-1.8B": {
@ -1068,35 +1251,35 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
}, },
"Qwen-1.8B-int8-Chat": { "Qwen-1.8B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
}, },
"Qwen-1.8B-int4-Chat": { "Qwen-1.8B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
}, },
"Qwen-7B-int8-Chat": { "Qwen-7B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
}, },
"Qwen-7B-int4-Chat": { "Qwen-7B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
}, },
"Qwen-14B-int8-Chat": { "Qwen-14B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
}, },
"Qwen-14B-int4-Chat": { "Qwen-14B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
}, },
"Qwen-72B-int8-Chat": { "Qwen-72B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
}, },
"Qwen-72B-int4-Chat": { "Qwen-72B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
}, },
@ -1179,75 +1362,75 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
}, },
"Qwen1.5-0.5B-int8-Chat": { "Qwen1.5-0.5B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-0.5B-int4-Chat": { "Qwen1.5-0.5B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
}, },
"Qwen1.5-1.8B-int8-Chat": { "Qwen1.5-1.8B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-1.8B-int4-Chat": { "Qwen1.5-1.8B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
}, },
"Qwen1.5-4B-int8-Chat": { "Qwen1.5-4B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-4B-int4-Chat": { "Qwen1.5-4B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
}, },
"Qwen1.5-7B-int8-Chat": { "Qwen1.5-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-7B-int4-Chat": { "Qwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
}, },
"Qwen1.5-14B-int8-Chat": { "Qwen1.5-14B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-14B-int4-Chat": { "Qwen1.5-14B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
}, },
"Qwen1.5-32B-int4-Chat": { "Qwen1.5-32B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
}, },
"Qwen1.5-72B-int8-Chat": { "Qwen1.5-72B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
}, },
"Qwen1.5-72B-int4-Chat": { "Qwen1.5-72B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
}, },
"Qwen1.5-110B-int4-Chat": { "Qwen1.5-110B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
}, },
"Qwen1.5-MoE-A2.7B-int4-Chat": { "Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
}, },
"Qwen1.5-Code-7B": { "CodeQwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B", DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B", DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
}, },
"Qwen1.5-Code-7B-Chat": { "CodeQwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat", DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat", DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
}, },
"Qwen1.5-Code-7B-int4-Chat": { "CodeQwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ", DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ", DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
}, },
@ -1281,14 +1464,17 @@ register_model_group(
"Qwen2-0.5B-Instruct": { "Qwen2-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
}, },
"Qwen2-1.5B-Instruct": { "Qwen2-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
}, },
"Qwen2-7B-Instruct": { "Qwen2-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
}, },
"Qwen2-72B-Instruct": { "Qwen2-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
@ -1568,51 +1754,53 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Qwen2VL-2B-Instruct": { "Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
}, },
"Qwen2VL-7B-Instruct": { "Qwen2-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
}, },
"Qwen2VL-72B-Instruct": { "Qwen2-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
}, },
"Qwen2VL-2B-Instruct-GPTQ-Int8": { "Qwen2-VL-2B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
}, },
"Qwen2VL-2B-Instruct-GPTQ-Int4": { "Qwen2-VL-2B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
}, },
"Qwen2VL-2B-Instruct-AWQ": { "Qwen2-VL-2B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
}, },
"Qwen2VL-7B-Instruct-GPTQ-Int8": { "Qwen2-VL-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
}, },
"Qwen2VL-7B-Instruct-GPTQ-Int4": { "Qwen2-VL-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
}, },
"Qwen2VL-7B-Instruct-AWQ": { "Qwen2-VL-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
}, },
"Qwen2VL-72B-Instruct-GPTQ-Int8": { "Qwen2-VL-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
}, },
"Qwen2VL-72B-Instruct-GPTQ-Int4": { "Qwen2-VL-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
}, },
"Qwen2VL-72B-Instruct-AWQ": { "Qwen2-VL-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
}, },
@ -1673,10 +1861,12 @@ register_model_group(
"TeleChat-7B-Chat": { "TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B", DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B", DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
}, },
"TeleChat-12B-Chat": { "TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B", DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B", DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
}, },
"TeleChat-12B-v2-Chat": { "TeleChat-12B-v2-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2", DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
@ -1689,11 +1879,11 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Vicuna1.5-7B-Chat": { "Vicuna-v1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5", DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
}, },
"Vicuna1.5-13B-Chat": { "Vicuna-v1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5", DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
}, },
@ -1702,6 +1892,17 @@ register_model_group(
) )
register_model_group(
models={
"Video-LLaVA-7B-Chat": {
DownloadSource.DEFAULT: "LanguageBind/Video-LLaVA-7B-hf",
},
},
template="video_llava",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"XuanYuan-6B": { "XuanYuan-6B": {
@ -1712,7 +1913,7 @@ register_model_group(
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
}, },
"XuanYuan-2-70B": { "XuanYuan2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
}, },
@ -1724,31 +1925,31 @@ register_model_group(
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
}, },
"XuanYuan-2-70B-Chat": { "XuanYuan2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
}, },
"XuanYuan-6B-int8-Chat": { "XuanYuan-6B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
}, },
"XuanYuan-6B-int4-Chat": { "XuanYuan-6B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
}, },
"XuanYuan-70B-int8-Chat": { "XuanYuan-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
}, },
"XuanYuan-70B-int4-Chat": { "XuanYuan-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
}, },
"XuanYuan-2-70B-int8-Chat": { "XuanYuan2-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
}, },
"XuanYuan-2-70B-int4-Chat": { "XuanYuan2-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
}, },
@ -1853,19 +2054,19 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
}, },
"Yi-6B-int8-Chat": { "Yi-6B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
}, },
"Yi-6B-int4-Chat": { "Yi-6B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits", DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
}, },
"Yi-34B-int8-Chat": { "Yi-34B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
}, },
"Yi-34B-int4-Chat": { "Yi-34B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits", DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
}, },
@ -1884,6 +2085,7 @@ register_model_group(
"Yi-1.5-6B-Chat": { "Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
}, },
"Yi-1.5-9B-Chat": { "Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
@ -1916,10 +2118,10 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"YiVL-6B-Chat": { "Yi-VL-6B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf", DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
}, },
"YiVL-34B-Chat": { "Yi-VL-34B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf", DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
}, },
}, },
View File
@ -72,4 +72,4 @@ def print_env() -> None:
except Exception: except Exception:
pass pass
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n") print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
View File
@ -20,6 +20,7 @@ import os
import sys import sys
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Optional from typing import Optional
from .constants import RUNNING_LOG from .constants import RUNNING_LOG
@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
def __init__(self, output_dir: str) -> None: def __init__(self, output_dir: str) -> None:
super().__init__() super().__init__()
formatter = logging.Formatter( self._formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
) )
self.setLevel(logging.INFO) self.setLevel(logging.INFO)
self.setFormatter(formatter)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
self.running_log = os.path.join(output_dir, RUNNING_LOG) self.running_log = os.path.join(output_dir, RUNNING_LOG)
if os.path.exists(self.running_log): if os.path.exists(self.running_log):
@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
if record.name == "httpx": if record.name == "httpx":
return return
log_entry = self.format(record) log_entry = self._formatter.format(record)
self.thread_pool.submit(self._write_log, log_entry) self.thread_pool.submit(self._write_log, log_entry)
def close(self) -> None: def close(self) -> None:
@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
return super().close() return super().close()
class _Logger(logging.Logger):
r"""
A logger that supports info_rank0, warning_rank0 and warning_once.
"""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def _get_default_logging_level() -> "logging._Level": def _get_default_logging_level() -> "logging._Level":
r""" r"""
Returns the default logging level. Returns the default logging level.
@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
if env_level_str.upper() in logging._nameToLevel: if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()] return logging._nameToLevel[env_level_str.upper()]
else: else:
raise ValueError("Unknown logging level: {}.".format(env_level_str)) raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level return _default_log_level
@ -84,7 +99,7 @@ def _get_library_name() -> str:
return __name__.split(".")[0] return __name__.split(".")[0]
def _get_library_root_logger() -> "logging.Logger": def _get_library_root_logger() -> "_Logger":
return logging.getLogger(_get_library_name()) return logging.getLogger(_get_library_name())
@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
global _default_handler global _default_handler
with _thread_lock: with _thread_lock:
if _default_handler: if _default_handler: # already configured
return return
formatter = logging.Formatter( formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
datefmt="%m/%d/%Y %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
) )
_default_handler = logging.StreamHandler(sys.stdout) _default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter) _default_handler.setFormatter(formatter)
@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "logging.Logger": def get_logger(name: Optional[str] = None) -> "_Logger":
r""" r"""
Returns a logger with the specified name. It is not supposed to be accessed externally. Returns a logger with the specified name. It is not supposed to be accessed externally.
""" """
@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
_configure_library_root_logger() _configure_library_root_logger()
return logging.getLogger(name) return logging.getLogger(name)
def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""
Removes a handler from the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.info(*args, **kwargs)
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once
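The logging changes above attach rank-aware helpers directly to `logging.Logger`. Below is a minimal, self-contained sketch of that pattern (not the project's full `_Logger` class): the `LOCAL_RANK` environment variable gates output to the local main process, and `lru_cache` deduplicates repeated warnings.

```python
import logging
import os
from functools import lru_cache


def info_rank0(self: logging.Logger, *args, **kwargs) -> None:
    # Only the local main process (LOCAL_RANK == 0) writes the message.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.info(*args, **kwargs)


@lru_cache(None)
def warning_once(self: logging.Logger, *args, **kwargs) -> None:
    # The cache keys on (logger, message), so identical warnings are emitted once.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_once = warning_once

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")
    logger.info_rank0("visible only on rank 0")
    logger.warning_once("emitted once")
    logger.warning_once("emitted once")  # suppressed by the cache
```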

View File

@ -32,7 +32,7 @@ from transformers.utils import (
) )
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from .logging import get_logger from . import logging
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@ -48,7 +48,7 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments from ..hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class AverageMeter: class AverageMeter:
@ -76,12 +76,12 @@ def check_dependencies() -> None:
r""" r"""
Checks the version of the required packages. Checks the version of the required packages.
""" """
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
else: else:
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0") require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0") require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2") require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0") require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6") require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@ -231,18 +231,35 @@ def torch_gc() -> None:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def try_download_model_from_ms(model_args: "ModelArguments") -> str: def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
if not use_modelscope() or os.path.exists(model_args.model_name_or_path): if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
return model_args.model_name_or_path return model_args.model_name_or_path
try: if use_modelscope():
from modelscope import snapshot_download require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision revision = "master" if model_args.model_revision == "main" else model_args.model_revision
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir) return snapshot_download(
except ImportError: model_args.model_name_or_path,
raise ImportError("Please install modelscope via `pip install modelscope -U`") revision=revision,
cache_dir=model_args.cache_dir,
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
model_args.model_name_or_path,
revision=model_args.model_revision,
cache_dir=model_args.cache_dir,
)
def use_modelscope() -> bool: def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
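For readers unfamiliar with the hub switch, the routing introduced above can be condensed into the sketch below. `resolve_model_path` is a hypothetical stand-in for `try_download_model_from_other_hub`; the `snapshot_download` calls mirror the ones shown in this diff and require the optional `modelscope` / `openmind` packages.

```python
import os


def use_modelscope() -> bool:
    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]


def use_openmind() -> bool:
    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]


def resolve_model_path(model_id: str, revision: str = "main", cache_dir=None) -> str:
    # Local paths and the default Hugging Face Hub are used unchanged.
    if (not use_modelscope() and not use_openmind()) or os.path.exists(model_id):
        return model_id

    if use_modelscope():
        from modelscope import snapshot_download  # requires modelscope>=1.11.0

        # ModelScope names its default branch "master" rather than "main".
        revision = "master" if revision == "main" else revision
        return snapshot_download(model_id, revision=revision, cache_dir=cache_dir)

    from openmind.utils.hub import snapshot_download  # requires openmind>=0.8.0

    return snapshot_download(model_id, revision=revision, cache_dir=cache_dir)
```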

View File

@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
return _get_package_version("transformers") >= version.parse("4.43.0") return _get_package_version("transformers") >= version.parse("4.43.0")
@lru_cache
def is_transformers_version_equal_to_4_46():
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
def is_uvicorn_available(): def is_uvicorn_available():
return _is_package_available("uvicorn") return _is_package_available("uvicorn")

View File

@ -19,7 +19,7 @@ from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer import TRAINER_STATE_NAME
from .logging import get_logger from . import logging
from .packages import is_matplotlib_available from .packages import is_matplotlib_available
@ -28,7 +28,7 @@ if is_matplotlib_available():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]: def smooth(scalars: List[float]) -> List[float]:
@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
Plots loss curves and saves the image. Plots loss curves and saves the image.
""" """
plt.switch_backend("agg") plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
for key in keys: for key in keys:
@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
metrics.append(data["log_history"][i][key]) metrics.append(data["log_history"][i][key])
if len(metrics) == 0: if len(metrics) == 0:
logger.warning(f"No metric {key} to plot.") logger.warning_rank0(f"No metric {key} to plot.")
continue continue
plt.figure() plt.figure()
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original") plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed") plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
plt.title("training {} of {}".format(key, save_dictionary)) plt.title(f"training {key} of {save_dictionary}")
plt.xlabel("step") plt.xlabel("step")
plt.ylabel(key) plt.ylabel(key)
plt.legend() plt.legend()

View File

@ -41,6 +41,10 @@ class DataArguments:
default="data", default="data",
metadata={"help": "Path to the folder containing the datasets."}, metadata={"help": "Path to the folder containing the datasets."},
) )
image_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
)
cutoff_len: int = field( cutoff_len: int = field(
default=1024, default=1024,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
@ -111,7 +115,13 @@ class DataArguments:
) )
tokenized_path: Optional[str] = field( tokenized_path: Optional[str] = field(
default=None, default=None,
metadata={"help": "Path to save or load the tokenized datasets."}, metadata={
"help": (
"Path to save or load the tokenized datasets. "
"If tokenized_path does not exist, it will save the tokenized datasets. "
"If tokenized_path exists, it will load the tokenized datasets."
)
},
) )
def __post_init__(self): def __post_init__(self):
@ -123,6 +133,9 @@ class DataArguments:
self.dataset = split_arg(self.dataset) self.dataset = split_arg(self.dataset)
self.eval_dataset = split_arg(self.eval_dataset) self.eval_dataset = split_arg(self.eval_dataset)
if self.image_dir is None:
self.image_dir = self.dataset_dir
if self.dataset is None and self.val_size > 1e-6: if self.dataset is None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `dataset` is None.") raise ValueError("Cannot specify `val_size` if `dataset` is None.")

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import torch import torch
@ -267,6 +267,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=None, default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."}, metadata={"help": "Auth token to log in with ModelScope Hub."},
) )
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field( print_param_status: bool = field(
default=False, default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
@ -308,20 +312,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
if self.export_quantization_bit is not None and self.export_quantization_dataset is None: if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.") raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@classmethod @classmethod
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self": def copyfrom(cls, source: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict() init_args, lazy_args = {}, {}
arg_dict.update(**kwargs) for attr in fields(source):
for attr in fields(cls): if attr.init:
if not attr.init: init_args[attr.name] = getattr(source, attr.name)
arg_dict.pop(attr.name) else:
lazy_args[attr.name] = getattr(source, attr.name)
new_arg = cls(**arg_dict) init_args.update(kwargs)
new_arg.compute_dtype = old_arg.compute_dtype result = cls(**init_args)
new_arg.device_map = old_arg.device_map for name, value in lazy_args.items():
new_arg.model_max_length = old_arg.model_max_length setattr(result, name, value)
new_arg.block_diag_attn = old_arg.block_diag_attn
return new_arg return result
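The rewritten `copyfrom` relies on the dataclass `fields()` metadata: fields with `init=True` are passed to the constructor, while `init=False` fields (set lazily at runtime) are copied onto the new instance afterwards. A minimal sketch of the same pattern with a hypothetical `Args` dataclass:

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class Args:
    model_name_or_path: str = "demo"
    cache_dir: Optional[str] = None
    compute_dtype: Optional[str] = field(default=None, init=False)  # derived at runtime

    @classmethod
    def copyfrom(cls, source: "Args", **kwargs) -> "Args":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            if attr.init:
                init_args[attr.name] = getattr(source, attr.name)
            else:
                lazy_args[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)  # per-call overrides win over copied values
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)

        return result


old = Args(cache_dir="/tmp/cache")
old.compute_dtype = "bfloat16"
new = Args.copyfrom(old, model_name_or_path="another-model")
assert new.compute_dtype == "bfloat16" and new.cache_dir == "/tmp/cache"
```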

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import os import os
import sys import sys
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES from ..extras.constants import CHECKPOINT_NAMES
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments from .data_args import DataArguments
from .evaluation_args import EvaluationArguments from .evaluation_args import EvaluationArguments
@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
from .model_args import ModelArguments from .model_args import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
check_dependencies() check_dependencies()
@ -57,7 +56,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
if args is not None: if args is not None:
return parser.parse_dict(args) return parser.parse_dict(args)
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
if unknown_args: if unknown_args:
print(parser.format_help()) print(parser.format_help())
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,) return (*parsed_args,)
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: def _set_transformers_logging() -> None:
transformers.utils.logging.set_verbosity(log_level) transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format() transformers.utils.logging.enable_explicit_format()
@ -104,7 +103,7 @@ def _verify_model_args(
raise ValueError("Quantized model only accepts a single adapter. Merge them first.") raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer: if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False model_args.use_fast_tokenizer = False
@ -123,7 +122,7 @@ def _check_extra_dependencies(
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0") require_version("vllm>=0.4.3,<=0.6.3", "To fix: pip install vllm>=0.4.3,<=0.6.3")
if finetuning_args.use_galore: if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch") require_version("galore_torch", "To fix: pip install galore_torch")
@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing: if data_args.neat_packing and not data_args.packing:
logger.warning("`neat_packing` requires `packing` to be True. Change `packing` to True.") logger.warning_rank0("`neat_packing` requires `packing` to be True. Change `packing` to True.")
data_args.packing = True data_args.packing = True
_verify_model_args(model_args, data_args, finetuning_args) _verify_model_args(model_args, data_args, finetuning_args)
@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and model_args.resize_vocab and model_args.resize_vocab
and finetuning_args.additional_target is None and finetuning_args.additional_target is None
): ):
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.") logger.warning_rank0(
"Remember to add embedding layers to `additional_target` to make the added tokens trainable."
)
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enabling `upcast_layernorm` in quantized training.") logger.warning_rank0("We recommend enabling `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enabling mixed precision training.") logger.warning_rank0("We recommend enabling mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
logger.warning("Using GaLore with mixed precision training may significantly increase GPU memory usage.") logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increase GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None: if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.") logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments # Post-process training arguments
if ( if (
@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and training_args.ddp_find_unused_parameters is None and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora" and finetuning_args.finetuning_type == "lora"
): ):
logger.warning("`ddp_find_unused_parameters` needs to be set to False for LoRA in DDP training.") logger.warning_rank0("`ddp_find_unused_parameters` needs to be set to False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None: if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.") logger.warning_rank0("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None training_args.resume_from_checkpoint = None
else: else:
can_resume_from_checkpoint = True can_resume_from_checkpoint = True
@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if last_checkpoint is not None: if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint training_args.resume_from_checkpoint = last_checkpoint
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint)) logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.") logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if ( if (
finetuning_args.stage in ["rm", "ppo"] finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora" and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint is not None
): ):
logger.warning( logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint training_args.resume_from_checkpoint
) )

View File

@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@ -33,7 +33,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _setup_full_tuning( def _setup_full_tuning(
@ -45,7 +45,7 @@ def _setup_full_tuning(
if not is_trainable: if not is_trainable:
return return
logger.info("Fine-tuning method: Full") logger.info_rank0("Fine-tuning method: Full")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args) forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules): if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@ -64,7 +64,7 @@ def _setup_freeze_tuning(
if not is_trainable: if not is_trainable:
return return
logger.info("Fine-tuning method: Freeze") logger.info_rank0("Fine-tuning method: Freeze")
if hasattr(model.config, "text_config"): # composite models if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config") config = getattr(model.config, "text_config")
else: else:
@ -133,7 +133,7 @@ def _setup_freeze_tuning(
else: else:
param.requires_grad_(False) param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(trainable_layers))) logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning( def _setup_lora_tuning(
@ -145,7 +145,7 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
) -> "PeftModel": ) -> "PeftModel":
if is_trainable: if is_trainable:
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None adapter_to_resume = None
@ -182,7 +182,7 @@ def _setup_lora_tuning(
model = model.merge_and_unload() model = model.merge_and_unload()
if len(adapter_to_merge) > 0: if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth: if model_args.use_unsloth:
@ -190,7 +190,7 @@ def _setup_lora_tuning(
else: else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@ -219,7 +219,7 @@ def _setup_lora_tuning(
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = { peft_kwargs = {
"r": finetuning_args.lora_rank, "r": finetuning_args.lora_rank,
@ -236,11 +236,11 @@ def _setup_lora_tuning(
else: else:
if finetuning_args.pissa_init: if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1: if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.") logger.info_rank0("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa" peft_kwargs["init_lora_weights"] = "pissa"
else: else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig( lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
@ -284,11 +284,11 @@ def init_adapter(
if not is_trainable: if not is_trainable:
pass pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam: elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.") logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.") logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else: else:
logger.info("Upcasting trainable params to float32.") logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full": if finetuning_args.finetuning_type == "full":
@ -300,6 +300,6 @@ def init_adapter(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
) )
else: else:
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type)) raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
return model return model

View File

@ -18,15 +18,15 @@ import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING: if TYPE_CHECKING:
@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class TokenizerModule(TypedDict): class TokenizerModule(TypedDict):
@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
skip_check_imports() skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args) model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
return { return {
"trust_remote_code": True, "trust_remote_code": True,
"cache_dir": model_args.cache_dir, "cache_dir": model_args.cache_dir,
@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r""" r"""
Loads pretrained tokenizer. Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
@ -82,33 +82,30 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
padding_side="right", padding_side="right",
**init_kwargs, **init_kwargs,
) )
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
if model_args.new_special_tokens is not None: if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens( num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens), dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False, replace_additional_special_tokens=False,
) )
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab: if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.") logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer) patch_tokenizer(tokenizer)
try: try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer) patch_processor(processor, config, tokenizer, model_args)
setattr(processor, "image_seqlen", get_image_seqlen(config)) except Exception as e:
setattr(processor, "image_resolution", model_args.image_resolution) logger.debug(f"Processor was not found: {e}.")
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
except Exception:
processor = None processor = None
# Avoid load tokenizer, see: # Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if "Processor" not in processor.__class__.__name__: if processor is not None and "Processor" not in processor.__class__.__name__:
processor = None processor = None
return {"tokenizer": tokenizer, "processor": processor} return {"tokenizer": tokenizer, "processor": processor}
@ -135,6 +132,7 @@ def load_model(
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args) config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None model = None
lazy_load = False lazy_load = False
@ -157,7 +155,7 @@ def load_model(
load_class = AutoModelForCausalLM load_class = AutoModelForCausalLM
if model_args.train_from_scratch: if model_args.train_from_scratch:
model = load_class.from_config(config) model = load_class.from_config(config, trust_remote_code=True)
else: else:
model = load_class.from_pretrained(**init_kwargs) model = load_class.from_pretrained(**init_kwargs)
@ -182,7 +180,7 @@ def load_model(
vhead_params = load_valuehead_params(vhead_path, model_args) vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None: if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False) model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
if not is_trainable: if not is_trainable:
model.requires_grad_(False) model.requires_grad_(False)
@ -200,9 +198,9 @@ def load_model(
trainable_params, all_param, 100 * trainable_params / all_param trainable_params, all_param, 100 * trainable_params / all_param
) )
else: else:
param_stats = "all params: {:,}".format(all_param) param_stats = f"all params: {all_param:,}"
logger.info(param_stats) logger.info_rank0(param_stats)
if model_args.print_param_status: if model_args.print_param_status:
for name, param in model.named_parameters(): for name, param in model.named_parameters():

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def configure_attn_implementation( def configure_attn_implementation(
@ -37,13 +37,16 @@ def configure_attn_implementation(
if is_flash_attn_2_available(): if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") if model_args.flash_attn != "fa2":
model_args.flash_attn = "fa2" logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else: else:
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.") logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled" model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.") logger.warning_rank0(
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
)
if model_args.flash_attn == "auto": if model_args.flash_attn == "auto":
return return
@ -53,18 +56,18 @@ def configure_attn_implementation(
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
if not is_torch_sdpa_available(): if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.") logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return return
requested_attn_implementation = "sdpa" requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2": elif model_args.flash_attn == "fa2":
if not is_flash_attn_2_available(): if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.") logger.warning_rank0("FlashAttention-2 is not installed.")
return return
requested_attn_implementation = "flash_attention_2" requested_attn_implementation = "flash_attention_2"
else: else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation) setattr(config, "attn_implementation", requested_attn_implementation)
@ -79,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
attn_implementation = getattr(config, "_attn_implementation", None) attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2": if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.") logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa": elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.") logger.info_rank0("Using torch SDPA for faster training and inference.")
else: else:
logger.info("Using vanilla attention implementation.") logger.info_rank0("Using vanilla attention implementation.")

View File

@ -19,14 +19,14 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from functools import partial, wraps from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable: def get_unsloth_gradient_checkpointing_func() -> Callable:
@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
Only applies gradient checkpointing to trainable layers. Only applies gradient checkpointing to trainable layers.
""" """
@wraps(gradient_checkpointing_func) @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs): def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__ module: "torch.nn.Module" = func.__self__
@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
return gradient_checkpointing_func(func, *args, **kwargs) return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func return custom_gradient_checkpointing_func
@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing: if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None: if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True} gradient_checkpointing_kwargs = {"use_reentrant": True}
@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True)) self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads() self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
(3) add the upcasting of the lm_head in fp32 (3) add the upcasting of the lm_head in fp32
""" """
if model_args.upcast_layernorm: if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.") logger.info_rank0("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing: if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False): if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.") logger.warning_rank0("Current model does not support gradient checkpointing.")
else: else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet) # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339 # According to: https://github.com/huggingface/transformers/issues/28339
@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model) model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.") logger.info_rank0("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output: if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings() output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.") logger.info_rank0("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook) output_layer.register_forward_hook(_fp32_forward_post_hook)
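The `WRAPPER_ASSIGNMENTS + ("__self__",)` trick above makes the checkpointing wrapper keep looking like a bound method, which is why the manual `__self__` copy could be deleted. A minimal sketch of the mechanism; note that `functools.wraps` simply skips attributes the wrapped callable does not have (e.g. a plain `torch.utils.checkpoint.checkpoint` function):

```python
from functools import WRAPPER_ASSIGNMENTS, wraps


class Module:
    def checkpoint(self, x):
        return x * 2


module = Module()
bound_func = module.checkpoint  # a bound method, so it carries __self__


@wraps(bound_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def wrapped(*args, **kwargs):
    # Extra bookkeeping could happen here before delegating to the original call.
    return bound_func(*args, **kwargs)


assert wrapped.__self__ is module  # __self__ survives the wrapping
assert wrapped(3) == 6
```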

View File

@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
@ -23,10 +24,15 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
is_trainable: bool,
require_logits: bool,
) -> None:
if not is_trainable or not model_args.enable_liger_kernel: if not is_trainable or not model_args.enable_liger_kernel:
return return
@ -48,8 +54,14 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
else: else:
logger.warning("Current model does not support liger kernel.") logger.warning_rank0("Current model does not support liger kernel.")
return return
apply_liger_kernel() if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Liger kernel has been applied to the model.") logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info_rank0("Liger kernel has been applied to the model.")

View File

@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
Cache, Cache,
LlamaAttention, LlamaAttention,
@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb, apply_rotary_pos_emb,
repeat_kv, repeat_kv,
) )
from transformers.utils import logging
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43 from ...extras.packages import is_transformers_version_greater_than_4_43
@ -44,7 +44,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
transformers_logger = logging.get_logger(__name__) transformers_logger = transformers.utils.logging.get_logger(__name__)
# Modified from: # Modified from:
@ -86,7 +86,7 @@ def llama_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None: def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0") require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
if not is_trainable or not model_args.shift_attn: if not is_trainable or not model_args.shift_attn:
return return
logger = get_logger(__name__) logger = logging.get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25) setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch() _apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.") logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
else: else:
logger.warning("Current model does not support shift short attention.") logger.warning_rank0("Current model does not support shift short attention.")

View File

@ -14,14 +14,14 @@
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@ -34,7 +34,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output_layer") forbidden_modules.add("output_layer")
elif model_type == "internlm2": elif model_type == "internlm2":
forbidden_modules.add("output") forbidden_modules.add("output")
elif model_type in ["llava", "paligemma"]: elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
forbidden_modules.add("multi_modal_projector") forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
forbidden_modules.add("merger") forbidden_modules.add("merger")
@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names))) logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
return list(module_names) return list(module_names)
@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
if num_layers % num_layer_trainable != 0: if num_layers % num_layer_trainable != 0:
raise ValueError( raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable) f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
) )
stride = num_layers // num_layer_trainable stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids] trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
module_names = [] module_names = []
for name, _ in model.named_modules(): for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any( if any(target_module in name for target_module in target_modules) and any(
@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
): ):
module_names.append(name) module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names return module_names
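The stride arithmetic in `find_expanded_modules` is easiest to see with concrete numbers. Assuming 32 layers and 8 trainable layers, the sketch below reproduces which layer indices receive the expanded LoRA modules:

```python
num_layers, num_layer_trainable = 32, 8
if num_layers % num_layer_trainable != 0:
    raise ValueError(
        f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
    )

stride = num_layers // num_layer_trainable  # 4
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
print(trainable_layers)  # ['.3.', '.7.', '.11.', '.15.', '.19.', '.23.', '.27.', '.31.']
```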

Some files were not shown because too many files have changed in this diff.