Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-12 16:12:48 +08:00)

Commit d99e164cad: Merge branch 'main' into main
Former-commit-id: 5f14910910154ba569435e7e68acbd6c30f79e80
(first changed file; name not captured)

@@ -7,6 +7,8 @@ data
 docker
 saves
 hf_cache
+ms_cache
+om_cache
 output
 .dockerignore
 .gitattributes
17  .env.local

@@ -1,32 +1,35 @@
 # Note: actually we do not support .env, just for reference
 # api
-API_HOST=0.0.0.0
-API_PORT=8000
+API_HOST=
+API_PORT=
 API_KEY=
-API_MODEL_NAME=gpt-3.5-turbo
+API_MODEL_NAME=
 FASTAPI_ROOT_PATH=
+MAX_CONCURRENT=
 # general
 DISABLE_VERSION_CHECK=
 FORCE_CHECK_IMPORTS=
 LLAMAFACTORY_VERBOSITY=
 USE_MODELSCOPE_HUB=
+USE_OPENMIND_HUB=
 RECORD_VRAM=
 # torchrun
 FORCE_TORCHRUN=
 MASTER_ADDR=
 MASTER_PORT=
 NNODES=
-RANK=
+NODE_RANK=
 NPROC_PER_NODE=
 # wandb
 WANDB_DISABLED=
-WANDB_PROJECT=huggingface
+WANDB_PROJECT=
 WANDB_API_KEY=
 # gradio ui
-GRADIO_SHARE=False
-GRADIO_SERVER_NAME=0.0.0.0
+GRADIO_SHARE=
+GRADIO_SERVER_NAME=
 GRADIO_SERVER_PORT=
 GRADIO_ROOT_PATH=
+GRADIO_IPV6=
 # setup
 ENABLE_SHORT_CONSOLE=1
 # reserved (do not use)
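For context, these variables are read by `llamafactory-cli` at startup. A minimal sketch of a multi-node launch using the renamed `NODE_RANK` variable follows; the addresses, counts, and config path are placeholders, not part of the commit:

```bash
# Hypothetical 2-node launch; replace the address, ranks, and YAML path for your setup.
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 \
    llamafactory-cli train path/to/your_sft_config.yaml
```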
46  .github/CONTRIBUTING.md (vendored)

@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
 ### Style guide
 
 LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
+
+### Create a Pull Request
+
+1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
+
+2. Clone your fork to your local disk, and add the base repository as a remote:
+
+   ```bash
+   git clone git@github.com:[username]/LLaMA-Factory.git
+   cd LLaMA-Factory
+   git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
+   ```
+
+3. Create a new branch to hold your development changes:
+
+   ```bash
+   git checkout -b dev_your_branch
+   ```
+
+4. Set up a development environment by running the following command in a virtual environment:
+
+   ```bash
+   pip install -e ".[dev]"
+   ```
+
+   If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the -e flag.
+
+5. Check code before commit:
+
+   ```bash
+   make commit
+   make style && make quality
+   make test
+   ```
+
+6. Submit changes:
+
+   ```bash
+   git add .
+   git commit -m "commit message"
+   git fetch upstream
+   git rebase upstream/main
+   git push -u origin dev_your_branch
+   ```
+
+7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).
6  .github/workflows/publish.yml (vendored)

@@ -26,15 +26,15 @@ jobs:
       uses: actions/setup-python@v5
       with:
         python-version: "3.8"
 
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
         python -m pip install build
 
     - name: Build package
       run: |
         python -m build
 
     - name: Publish package
       uses: pypa/gh-action-pypi-publish@release/v1
3  .github/workflows/tests.yml (vendored)

@@ -22,7 +22,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version:
-          - "3.8"
+          - "3.8" # TODO: remove py38 in next transformers release
           - "3.9"
           - "3.10"
           - "3.11"
@@ -54,7 +54,6 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        python -m pip install git+https://github.com/huggingface/transformers.git
        python -m pip install ".[torch,dev]"
 
     - name: Check quality
1  .gitignore (vendored)

@@ -162,6 +162,7 @@ cython_debug/
 # custom .gitignore
 ms_cache/
 hf_cache/
+om_cache/
 cache/
 config/
 saves/
28  .pre-commit-config.yaml (new file)

@@ -0,0 +1,28 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: check-ast
+      - id: check-added-large-files
+        args: ['--maxkb=25000']
+      - id: check-merge-conflict
+      - id: check-yaml
+      - id: debug-statements
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+        args: [--markdown-linebreak-ext=md]
+      - id: no-commit-to-branch
+        args: ['--branch', 'main']
+
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.17.0
+    hooks:
+      - id: pyupgrade
+        args: [--py38-plus]
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.9
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
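For reference, the new hooks are enabled with the standard pre-commit workflow; this is a usage sketch and not part of the commit, assuming `pre-commit` is available in the environment (for example via the project's dev extras):

```bash
pip install pre-commit        # assumed to be available; also pulled in by `pip install -e ".[dev]"`
pre-commit install            # register the git hook defined by .pre-commit-config.yaml
pre-commit run --all-files    # run check-ast, end-of-file-fixer, ruff, etc. across the repository
```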
11  Makefile

@@ -1,7 +1,14 @@
-.PHONY: quality style test
+.PHONY: build commit quality style test
 
 check_dirs := scripts src tests setup.py
 
+build:
+	pip install build && python -m build
+
+commit:
+	pre-commit install
+	pre-commit run --all-files
+
 quality:
 	ruff check $(check_dirs)
 	ruff format --check $(check_dirs)
@@ -11,4 +18,4 @@ style:
 	ruff format $(check_dirs)
 
 test:
-	CUDA_VISIBLE_DEVICES= pytest tests/
+	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
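With these targets in place, a typical local check cycle might look like the following sketch (it assumes `pre-commit`, `ruff`, and `pytest` have been installed, e.g. via `pip install -e ".[dev]"`):

```bash
make commit                  # install and run the pre-commit hooks (new target)
make style && make quality   # ruff format + ruff check over scripts, src, tests, setup.py
make test                    # CPU-only pytest run, now verbose and with W&B disabled
```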
104  README.md

@@ -4,7 +4,7 @@
 [](LICENSE)
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llamafactory/)
 [](#projects-using-llama-factory)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)
@@ -26,10 +26,17 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
 Choose your path:
 
 - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
-- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
 - **Local machine**: Please refer to [usage](#getting-started)
 - **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
+
+Recent activities:
+
+- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)
+
+> [!NOTE]
+> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
 
 ## Table of Contents
 
 - [Features](#features)
@@ -72,6 +79,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Changelog
 
+[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
+
 [24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
 
 [24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
@@ -130,7 +139,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
 
-[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
 
 [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
 
@@ -162,36 +171,39 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Supported Models
 
 | Model | Model size | Template |
-| ----------------------------------------------------------------- | -------------------------------- | --------- |
+| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
 | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
 | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
 | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
-| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
-| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
-| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi-small |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
-| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
-| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
+| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
+| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
+| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
+| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
+| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
+| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
 
 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -360,7 +372,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
 
 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.
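For reference, the new `openmind` extra is selected the same way as the existing ones; a hypothetical install combining it with the defaults might be:

```bash
# Sketch only: combine the documented extras as needed for your environment.
pip install -e ".[torch,metrics,openmind]"
```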
@@ -412,7 +424,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij
 
 ### Data Preparation
 
-Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
+Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope / Modelers hub or load the dataset in local disk.
 
 > [!NOTE]
 > Please update `data/dataset_info.json` to use your custom dataset.
@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
 docker run -dit --gpus=all \
     -v ./hf_cache:/root/.cache/huggingface \
     -v ./ms_cache:/root/.cache/modelscope \
+    -v ./om_cache:/root/.cache/openmind \
     -v ./data:/app/data \
     -v ./output:/app/output \
     -p 7860:7860 \
@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
 docker run -dit \
     -v ./hf_cache:/root/.cache/huggingface \
     -v ./ms_cache:/root/.cache/modelscope \
+    -v ./om_cache:/root/.cache/openmind \
     -v ./data:/app/data \
     -v ./output:/app/output \
     -v /usr/local/dcmi:/usr/local/dcmi \
@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
 docker run -dit \
     -v ./hf_cache:/root/.cache/huggingface \
     -v ./ms_cache:/root/.cache/modelscope \
+    -v ./om_cache:/root/.cache/openmind \
     -v ./data:/app/data \
     -v ./output:/app/output \
     -v ./saves:/app/saves \
@@ -557,6 +572,7 @@ docker exec -it llamafactory bash
 
 - `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
 - `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
+- `om_cache`: Similar to Hugging Face cache but for Modelers users.
 - `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
 - `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 
 > [!TIP]
 > Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
+>
+> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
 
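As a usage sketch (not part of the commit), the OpenAI-style endpoint started above can be exercised with a plain HTTP call; the model name in the payload and the port are assumptions based on the defaults shown in the hunk header:

```bash
# Hypothetical request against the local API launched with `llamafactory-cli api`.
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "placeholder-model-name", "messages": [{"role": "user", "content": "Hello!"}]}'
```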
 ### Download from ModelScope Hub
 
@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
 
 Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
 
+### Download from Modelers Hub
+
+You can also use Modelers Hub to download models and datasets.
+
+```bash
+export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
+```
+
+Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
+
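A minimal sketch of the Modelers workflow described above; the YAML path is a placeholder, and the model ID comes from the example in the text:

```bash
export USE_OPENMIND_HUB=1   # `set USE_OPENMIND_HUB=1` on Windows
# In your training YAML (placeholder path below), set:
#   model_name_or_path: TeleAI/TeleChat-7B-pt
llamafactory-cli train path/to/your_config.yaml
```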
 ### Use W&B Logger
 
 To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
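A hedged sketch of that setup: export a W&B key (the value shown is a placeholder), add the wandb-related arguments, such as a `run_name`, to the training YAML, and launch as usual:

```bash
# Hypothetical: the key is a placeholder; the YAML path and its wandb arguments are assumptions.
export WANDB_API_KEY=<your_api_key>
llamafactory-cli train path/to/your_config.yaml
```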
@@ -684,11 +712,13 @@ If you have a project that should be incorporated, please contact via email or c
 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
-1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
 1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
+1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
+
 
 </details>
 
@@ -696,7 +726,7 @@ If you have a project that should be incorporated, please contact via email or c
 
 This repository is licensed under the [Apache-2.0 License](LICENSE).
 
-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 
 ## Citation
 
99  README_zh.md

@@ -4,7 +4,7 @@
 [](LICENSE)
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llamafactory/)
 [](#使用了-llama-factory-的项目)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)
@@ -26,11 +26,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 选择你的打开方式:
 
 - **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
-- **PAI-DSW**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **PAI-DSW**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
 - **本地机器**:请见[如何使用](#如何使用)
 - **入门教程**:https://zhuanlan.zhihu.com/p/695287607
 - **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
+
+近期活动:
+
+- **2024/10/18-2024/11/30**:使用 PAI+LLaMA Factory 构建个性化导游机器人。[[活动页面]](https://developer.aliyun.com/topic/llamafactory2)
+
+> [!NOTE]
+> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
 
 ## 目录
 
 - [项目特色](#项目特色)
@@ -73,6 +80,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 
 ## 更新日志
 
+[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
+
 [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
 
 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
@@ -163,35 +172,38 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 
 ## 模型
 
 | 模型名 | 模型大小 | Template |
-| ----------------------------------------------------------------- | -------------------------------- | --------- |
+| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
 | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
 | [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
 | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
-| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
-| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
-| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
-| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
+| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
+| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
+| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
+| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
 
 > [!NOTE]
 > 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
@@ -360,7 +372,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
+可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality
 
 > [!TIP]
 > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
@@ -412,7 +424,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 
 ### 数据准备
 
-关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
+关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope / Modelers 上的数据集或加载本地数据集。
 
 > [!NOTE]
 > 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
 docker run -dit --gpus=all \
     -v ./hf_cache:/root/.cache/huggingface \
     -v ./ms_cache:/root/.cache/modelscope \
+    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
     -v ./output:/app/output \
     -p 7860:7860 \
@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
 docker run -dit \
     -v ./hf_cache:/root/.cache/huggingface \
     -v ./ms_cache:/root/.cache/modelscope \
+    -v ./om_cache:/root/.cache/openmind \
     -v ./data:/app/data \
     -v ./output:/app/output \
     -v /usr/local/dcmi:/usr/local/dcmi \
@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
 docker run -dit \
     -v ./hf_cache:/root/.cache/huggingface \
     -v ./ms_cache:/root/.cache/modelscope \
+    -v ./om_cache:/root/.cache/openmind \
     -v ./data:/app/data \
     -v ./output:/app/output \
     -v ./saves:/app/saves \
@@ -557,6 +572,7 @@ docker exec -it llamafactory bash
 
 - `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
 - `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
+- `om_cache`:类似 Hugging Face 缓存文件夹,为 Modelers 用户提供。
 - `data`:宿主机中存放数据集的文件夹路径。
 - `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 
 > [!TIP]
 > API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
+>
+> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py)
 
 ### 从魔搭社区下载
 
@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 
 将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
 
+### 从魔乐社区下载
+
+您也可以通过下述方法,使用魔乐社区下载数据集和模型。
+
+```bash
+export USE_OPENMIND_HUB=1 # Windows 使用 `set USE_OPENMIND_HUB=1`
+```
+
+将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔乐社区](https://modelers.cn/models)查看所有可用的模型,例如 `TeleAI/TeleChat-7B-pt`。
+
 ### 使用 W&B 面板
 
 若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
@@ -684,11 +712,12 @@ run_name: test_run # 可选
 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
-1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
 1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调.
+1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
 
 </details>
 
@@ -696,7 +725,7 @@ run_name: test_run # 可选
 
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
 
-使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
 
 ## 引用
 
1630  assets/benchmark.svg (diff suppressed: file too large; 29 KiB -> 28 KiB)

(binary image file changed, not shown: 145 KiB -> 132 KiB)
(binary image file changed, not shown: 149 KiB -> 132 KiB)
@@ -4999,4 +4999,4 @@
     "input": "Time waits for no one.",
     "output": "No one can stop time from moving forward."
   }
 ]
@@ -4999,4 +4999,4 @@
     "input": "",
     "output": "安第斯山脉位于南美洲,横跨七个国家,包括委内瑞拉,哥伦比亚,厄瓜多尔,秘鲁,玻利维亚,智利和阿根廷。安第斯山脉是世界上最长的山脉之一,全长约7,000千米(4,350英里),其山脉沿着南美洲西海岸蜿蜒延伸,平均海拔约为4,000米(13,000英尺)。在其南部,安第斯山脉宽度达到700千米(430英里),在其北部宽度约为500千米(310英里)。"
   }
 ]
@@ -17,9 +17,9 @@ _CITATION = """\
 }
 """
 
-_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
 _LICENSE = "gpl-3.0"
-_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
 
 
 class BelleMultiturn(datasets.GeneratorBasedBuilder):
@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
         return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
 
     def _generate_examples(self, filepath: str):
-        with open(filepath, "r", encoding="utf-8") as f:
+        with open(filepath, encoding="utf-8") as f:
             for key, row in enumerate(f):
                 data = json.loads(row)
                 conversations = []

(File diff suppressed because one or more lines are too long.)
@@ -54,7 +54,8 @@
   },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
-    "ms_hub_url": "llamafactory/alpaca_en"
+    "ms_hub_url": "llamafactory/alpaca_en",
+    "om_hub_url": "HaM/alpaca_en"
   },
   "alpaca_zh": {
     "hf_hub_url": "llamafactory/alpaca_zh",
@@ -66,7 +67,8 @@
   },
   "alpaca_gpt4_zh": {
     "hf_hub_url": "llamafactory/alpaca_gpt4_zh",
-    "ms_hub_url": "llamafactory/alpaca_gpt4_zh"
+    "ms_hub_url": "llamafactory/alpaca_gpt4_zh",
+    "om_hub_url": "State_Cloud/alpaca-gpt4-data-zh"
   },
   "glaive_toolcall_en": {
     "hf_hub_url": "llamafactory/glaive_toolcall_en",
@@ -623,4 +625,4 @@
   },
     "folder": "python"
   }
 }
@@ -7223,4 +7223,4 @@
       "value": "Abstraction is a principle in object-oriented programming that refers to the process of focusing on the essential features of an object or concept without emphasizing the details that are not important for its functionality. It enables programmers to create classes that represent the objects they want to work with, and specifies only the behavior of those objects, leaving out unnecessary details. Abstraction helps to make the code more maintainable, modular, and scalable. It also improves the productivity of developers by reducing the amount of code they need to write."
     }
   }
 ]
@@ -5055,4 +5055,4 @@
       "value": "C. 参与讨论"
     }
   }
 ]
@@ -9155,4 +9155,4 @@
     ],
     "tools": "[]"
   }
 ]
@@ -9019,4 +9019,4 @@
     ],
     "tools": "[]"
   }
 ]
@@ -8,9 +8,9 @@ import datasets
 _HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
 _DESCRIPTION = "Human preference data about helpfulness and harmlessness."
 _CITATION = ""
-_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
 _LICENSE = "mit"
-_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
 _URLS = {
     "train": [
         _URL + "harmless-base/train.jsonl.gz",
@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
     def _generate_examples(self, filepaths: List[str]):
         key = 0
         for filepath in filepaths:
-            with open(filepath, "r", encoding="utf-8") as f:
+            with open(filepath, encoding="utf-8") as f:
                 for row in f:
                     data = json.loads(row)
                     chosen = data["chosen"]
@@ -454,4 +454,4 @@
     "input": "",
     "output": "抱歉,我不是 OpenAI 开发的 ChatGPT,我是 {{author}} 开发的 {{name}},旨在为用户提供智能化的回答和帮助。"
   }
 ]
@@ -5395,4 +5395,4 @@
     ],
     "label": false
   }
 ]
@@ -137,4 +137,4 @@
       "mllm_demo_data/3.jpg"
     ]
   }
 ]
@@ -44,4 +44,4 @@
       "mllm_demo_data/3.mp4"
     ]
   }
 ]
@@ -20,9 +20,9 @@ _CITATION = """\
 }
 """
 
-_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
 _LICENSE = "cc-by-nc-4.0"
-_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
+_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
 
 
 class UltraChat(datasets.GeneratorBasedBuilder):
@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
 
     def _generate_examples(self, filepaths: List[str]):
         for filepath in filepaths:
-            with open(filepath, "r", encoding="utf-8") as f:
+            with open(filepath, encoding="utf-8") as f:
                 for row in f:
                     try:
                         data = json.loads(row)

(File diff suppressed because one or more lines are too long.)
@ -1,6 +1,7 @@
-# Use the NVIDIA official image with PyTorch 2.3.0
+# Default use the NVIDIA official image with PyTorch 2.3.0
-# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
-FROM nvcr.io/nvidia/pytorch:24.02-py3
+ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
+FROM ${BASE_IMAGE}

 # Define environments
 ENV MAX_JOBS=4

@ -12,6 +13,9 @@ ARG INSTALL_BNB=false
 ARG INSTALL_VLLM=false
 ARG INSTALL_DEEPSPEED=false
 ARG INSTALL_FLASHATTN=false
+ARG INSTALL_LIGER_KERNEL=false
+ARG INSTALL_HQQ=false
+ARG INSTALL_EETQ=false
 ARG PIP_INDEX=https://pypi.org/simple

 # Set the working directory

@ -38,6 +42,15 @@ RUN EXTRA_PACKAGES="metrics"; \
 if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
 EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
 fi; \
+if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
+fi; \
+if [ "$INSTALL_HQQ" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
+fi; \
+if [ "$INSTALL_EETQ" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
+fi; \
 pip install -e ".[$EXTRA_PACKAGES]"

 # Rebuild flash attention
@ -8,11 +8,15 @@ services:
 INSTALL_VLLM: false
 INSTALL_DEEPSPEED: false
 INSTALL_FLASHATTN: false
+INSTALL_LIGER_KERNEL: false
+INSTALL_HQQ: false
+INSTALL_EETQ: false
 PIP_INDEX: https://pypi.org/simple
 container_name: llamafactory
 volumes:
 - ../../hf_cache:/root/.cache/huggingface
 - ../../ms_cache:/root/.cache/modelscope
+- ../../om_cache:/root/.cache/openmind
 - ../../data:/app/data
 - ../../output:/app/output
 ports:
@ -10,6 +10,7 @@ services:
 volumes:
 - ../../hf_cache:/root/.cache/huggingface
 - ../../ms_cache:/root/.cache/modelscope
+- ../../om_cache:/root/.cache/openmind
 - ../../data:/app/data
 - ../../output:/app/output
 - /usr/local/dcmi:/usr/local/dcmi
@ -10,6 +10,8 @@ ARG INSTALL_BNB=false
 ARG INSTALL_VLLM=false
 ARG INSTALL_DEEPSPEED=false
 ARG INSTALL_FLASHATTN=false
+ARG INSTALL_LIGER_KERNEL=false
+ARG INSTALL_HQQ=false
 ARG PIP_INDEX=https://pypi.org/simple

 # Set the working directory

@ -36,6 +38,12 @@ RUN EXTRA_PACKAGES="metrics"; \
 if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
 EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
 fi; \
+if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
+fi; \
+if [ "$INSTALL_HQQ" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
+fi; \
 pip install -e ".[$EXTRA_PACKAGES]"

 # Rebuild flash attention
@ -8,11 +8,14 @@ services:
 INSTALL_VLLM: false
 INSTALL_DEEPSPEED: false
 INSTALL_FLASHATTN: false
+INSTALL_LIGER_KERNEL: false
+INSTALL_HQQ: false
 PIP_INDEX: https://pypi.org/simple
 container_name: llamafactory
 volumes:
 - ../../hf_cache:/root/.cache/huggingface
 - ../../ms_cache:/root/.cache/modelscope
+- ../../om_cache:/root/.cache/openmind
 - ../../data:/app/data
 - ../../output:/app/output
 - ../../saves:/app/saves
@ -207,4 +207,4 @@
 "name": "兽医学",
 "category": "STEM"
 }
 }

@ -267,4 +267,4 @@
 "name": "世界宗教",
 "category": "Humanities"
 }
 }

@ -227,4 +227,4 @@
 "name": "world religions",
 "category": "Humanities"
 }
 }
@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
 df = pd.read_csv(filepath, header=None)
 df.columns = ["question", "A", "B", "C", "D", "answer"]

-for i, instance in enumerate(df.to_dict(orient="records")):
-yield i, instance
+yield from enumerate(df.to_dict(orient="records"))
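For reference, the `_generate_examples` rewrite above is behavior-preserving: `yield from enumerate(iterable)` produces the same `(key, instance)` pairs as the explicit loop it replaces. A minimal, self-contained sketch (the two-row record list is a made-up stand-in for `df.to_dict(orient="records")`):

```python
# Both generators yield identical (index, record) pairs.
records = [{"question": "q1"}, {"question": "q2"}]  # hypothetical stand-in for df.to_dict(orient="records")


def gen_with_loop(rows):
    for i, instance in enumerate(rows):
        yield i, instance


def gen_with_yield_from(rows):
    yield from enumerate(rows)


assert list(gen_with_loop(records)) == list(gen_with_yield_from(records))
print(list(gen_with_yield_from(records)))  # [(0, {'question': 'q1'}), (1, {'question': 'q2'})]
```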
@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 #### Supervised Fine-Tuning on Multiple Nodes

 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```

 #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)

@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 #### 多机指令监督微调

 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```

 #### 使用 DeepSpeed ZeRO-3 平均分配显存
@ -25,4 +25,4 @@
 "contiguous_gradients": true,
 "round_robin_gradients": true
 }
 }

@ -25,4 +25,4 @@
 "contiguous_gradients": true,
 "round_robin_gradients": true
 }
 }

@ -29,4 +29,4 @@
 "contiguous_gradients": true,
 "round_robin_gradients": true
 }
 }

@ -27,4 +27,4 @@
 "stage3_max_reuse_distance": 1e9,
 "stage3_gather_16bit_weights_on_model_save": true
 }
 }

@ -35,4 +35,4 @@
 "stage3_max_reuse_distance": 1e9,
 "stage3_gather_16bit_weights_on_model_save": true
 }
 }
@ -1,9 +1,9 @@
-transformers>=4.41.2,<=4.45.0
+transformers>=4.41.2,<=4.46.1
-datasets>=2.16.0,<=2.21.0
+datasets>=2.16.0,<=3.0.2
-accelerate>=0.30.1,<=0.34.2
+accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
-gradio>=4.0.0
+gradio>=4.0.0,<5.0.0
 pandas>=2.0.0
 scipy
 einops

@ -19,3 +19,4 @@ fire
 packaging
 pyyaml
 numpy<2.0.0
+av
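The version windows above can be sanity-checked in a local environment; the following is only an illustrative sketch (the package names and bounds are copied from the updated requirements shown above, everything else is generic stdlib and `packaging` usage):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

# Bounds copied from the updated requirements.txt shown above.
BOUNDS = {
    "transformers": ("4.41.2", "4.46.1"),
    "datasets": ("2.16.0", "3.0.2"),
    "accelerate": ("0.34.0", "1.0.1"),
}

for name, (low, high) in BOUNDS.items():
    try:
        installed = Version(version(name))
        status = "ok" if Version(low) <= installed <= Version(high) else f"outside [{low}, {high}]"
        print(f"{name} {installed}: {status}")
    except PackageNotFoundError:
        print(f"{name}: not installed")
```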
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
 #
 # This code is inspired by the Microsoft's DeepSpeed library.

@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 imoneoi and the LlamaFactory team.
 #
 # This code is inspired by the imoneoi's OpenChat library.

@ -74,7 +73,7 @@ def calculate_lr(
 elif stage == "sft":
 data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
 else:
-raise NotImplementedError("Stage does not supported: {}.".format(stage))
+raise NotImplementedError(f"Stage does not supported: {stage}.")

 dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
 valid_tokens, total_tokens = 0, 0
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
 elif "4090" in device_name:
 return 98 * 1e12 * world_size
 else:
-raise NotImplementedError("Device not supported: {}.".format(device_name))
+raise NotImplementedError(f"Device not supported: {device_name}.")


 def calculate_mfu(

@ -140,10 +139,10 @@ def calculate_mfu(
 "bf16": True,
 }
 if deepspeed_stage in [2, 3]:
-args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
+args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"

 run_exp(args)
-with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
 result = json.load(f)

 if dist.is_initialized():

@ -157,7 +156,7 @@ def calculate_mfu(
 * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
 / compute_device_flops(world_size)
 )
-print("MFU: {:.2f}%".format(mfu_value * 100))
+print(f"MFU: {mfu_value * 100:.2f}%")


 if __name__ == "__main__":
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@ -100,7 +99,7 @@ def calculate_ppl(
 tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
 )
 else:
-raise NotImplementedError("Stage does not supported: {}.".format(stage))
+raise NotImplementedError(f"Stage does not supported: {stage}.")

 dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
 criterion = torch.nn.CrossEntropyLoss(reduction="none")

@ -125,8 +124,8 @@ def calculate_ppl(
 with open(save_name, "w", encoding="utf-8") as f:
 json.dump(perplexities, f, indent=2)

-print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
+print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
-print("Perplexities have been saved at {}.".format(save_name))
+print(f"Perplexities have been saved at {save_name}.")


 if __name__ == "__main__":
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@ -61,7 +60,7 @@ def length_cdf(
 for length, count in length_tuples:
 count_accu += count
 prob_accu += count / total_num * 100
-print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
+print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")


 if __name__ == "__main__":
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Tencent Inc. and the LlamaFactory team.
 #
 # This code is inspired by the Tencent's LLaMA-Pro library.

@ -40,7 +39,7 @@ if TYPE_CHECKING:

 def change_name(name: str, old_index: int, new_index: int) -> str:
-return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
+return name.replace(f".{old_index:d}.", f".{new_index:d}.")


 def block_expansion(

@ -76,27 +75,27 @@ def block_expansion(
 state_dict = model.state_dict()

 if num_layers % num_expand != 0:
-raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
+raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")

 split = num_layers // num_expand
 layer_cnt = 0
 output_state_dict = OrderedDict()
 for i in range(num_layers):
 for key, value in state_dict.items():
-if ".{:d}.".format(i) in key:
+if f".{i:d}." in key:
 output_state_dict[change_name(key, i, layer_cnt)] = value

-print("Add layer {} copied from layer {}".format(layer_cnt, i))
+print(f"Add layer {layer_cnt} copied from layer {i}")
 layer_cnt += 1
 if (i + 1) % split == 0:
 for key, value in state_dict.items():
-if ".{:d}.".format(i) in key:
+if f".{i:d}." in key:
 if "down_proj" in key or "o_proj" in key:
 output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
 else:
 output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

-print("Add layer {} expanded from layer {}".format(layer_cnt, i))
+print(f"Add layer {layer_cnt} expanded from layer {i}")
 layer_cnt += 1

 for key, value in state_dict.items():

@ -113,17 +112,17 @@ def block_expansion(
 torch.save(shard, os.path.join(output_dir, shard_file))

 if index is None:
-print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
 else:
 index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
 with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
 json.dump(index, f, indent=2, sort_keys=True)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")

 print("- Fine-tune this model with:")
-print("model_name_or_path: {}".format(output_dir))
+print(f"model_name_or_path: {output_dir}")
 print("finetuning_type: freeze")
-print("freeze_trainable_layers: {}".format(num_expand))
+print(f"freeze_trainable_layers: {num_expand}")
 print("use_llama_pro: true")
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 torch.save(shard, os.path.join(output_dir, shard_file))

 if index is None:
-print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
 else:
 index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
 with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
 json.dump(index, f, indent=2, sort_keys=True)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")


 def save_config(input_dir: str, output_dir: str):
-with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
 llama2_config_dict: Dict[str, Any] = json.load(f)

 llama2_config_dict["architectures"] = ["LlamaForCausalLM"]

@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):

 with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
 json.dump(llama2_config_dict, f, indent=2)
-print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


 def llamafy_baichuan2(
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 elif "lm_head" in key:
 llama2_state_dict[key] = value
 else:
-raise KeyError("Unable to process key {}".format(key))
+raise KeyError(f"Unable to process key {key}")

 weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
 shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 torch.save(shard, os.path.join(output_dir, shard_file))

 if index is None:
-print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
 else:
 index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
 with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
 json.dump(index, f, indent=2, sort_keys=True)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")

 return str(torch_dtype).replace("torch.", "")


 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
-with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
 qwen_config_dict: Dict[str, Any] = json.load(f)

 llama2_config_dict: Dict[str, Any] = OrderedDict()

@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):

 with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
 json.dump(llama2_config_dict, f, indent=2)
-print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


 def llamafy_qwen(
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.

@ -70,19 +69,19 @@ def quantize_loftq(
 setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
 setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
 peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
-print("Adapter weights saved in {}".format(loftq_dir))
+print(f"Adapter weights saved in {loftq_dir}")

 # Save base model
 base_model: "PreTrainedModel" = peft_model.unload()
 base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
 tokenizer.save_pretrained(output_dir)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")

 print("- Fine-tune this model with:")
-print("model_name_or_path: {}".format(output_dir))
+print(f"model_name_or_path: {output_dir}")
-print("adapter_name_or_path: {}".format(loftq_dir))
+print(f"adapter_name_or_path: {loftq_dir}")
 print("finetuning_type: lora")
-print("quantization_bit: {}".format(loftq_bits))
+print(f"quantization_bit: {loftq_bits}")


 if __name__ == "__main__":
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.

@ -54,7 +53,7 @@ def quantize_pissa(
 lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
 lora_dropout=lora_dropout,
 target_modules=lora_target,
-init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
 )

 # Init PiSSA model

@ -65,17 +64,17 @@ def quantize_pissa(
 setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
 setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
 peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
-print("Adapter weights saved in {}".format(pissa_dir))
+print(f"Adapter weights saved in {pissa_dir}")

 # Save base model
 base_model: "PreTrainedModel" = peft_model.unload()
 base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
 tokenizer.save_pretrained(output_dir)
-print("Model weights saved in {}".format(output_dir))
+print(f"Model weights saved in {output_dir}")

 print("- Fine-tune this model with:")
-print("model_name_or_path: {}".format(output_dir))
+print(f"model_name_or_path: {output_dir}")
-print("adapter_name_or_path: {}".format(pissa_dir))
+print(f"adapter_name_or_path: {pissa_dir}")
 print("finetuning_type: lora")
 print("pissa_init: false")
 print("pissa_convert: true")
65  scripts/test_image.py  Normal file
@ -0,0 +1,65 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from openai import OpenAI
+from transformers.utils.versions import require_version
+
+
+require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
+
+
+def main():
+    client = OpenAI(
+        api_key="{}".format(os.environ.get("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+    )
+    messages = []
+    messages.append(
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Output the color and number of each box."},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
+                },
+            ],
+        }
+    )
+    result = client.chat.completions.create(messages=messages, model="test")
+    messages.append(result.choices[0].message)
+    print("Round 1:", result.choices[0].message.content)
+    # The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
+    messages.append(
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "What kind of flower is this?"},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
+                },
+            ],
+        }
+    )
+    result = client.chat.completions.create(messages=messages, model="test")
+    messages.append(result.choices[0].message)
+    print("Round 2:", result.choices[0].message.content)
+    # The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
+
+
+if __name__ == "__main__":
+    main()
@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
11  setup.py
@ -20,7 +20,7 @@ from setuptools import find_packages, setup


 def get_version() -> str:
-with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
 file_content = f.read()
 pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
 (version,) = re.findall(pattern, file_content)

@ -28,7 +28,7 @@ def get_version() -> str:


 def get_requires() -> List[str]:
-with open("requirements.txt", "r", encoding="utf-8") as f:
+with open("requirements.txt", encoding="utf-8") as f:
 file_content = f.read()
 lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
 return lines

@ -54,13 +54,14 @@ extra_require = {
 "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
 "awq": ["autoawq"],
 "aqlm": ["aqlm[gpu]>=1.1.0"],
-"vllm": ["vllm>=0.4.3,<=0.6.0"],
+"vllm": ["vllm>=0.4.3,<=0.6.3"],
 "galore": ["galore-torch"],
 "badam": ["badam>=1.2.1"],
 "adam-mini": ["adam-mini"],
 "qwen": ["transformers_stream_generator"],
 "modelscope": ["modelscope"],
-"dev": ["ruff", "pytest"],
+"openmind": ["openmind"],
+"dev": ["pre-commit", "ruff", "pytest"],
 }


@ -71,7 +72,7 @@ def main():
 author="hiyouga",
 author_email="hiyouga" "@" "buaa.edu.cn",
 description="Easy-to-use LLM fine-tuning framework",
-long_description=open("README.md", "r", encoding="utf-8").read(),
+long_description=open("README.md", encoding="utf-8").read(),
 long_description_content_type="text/markdown",
 keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
 license="Apache 2.0 License",
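As an aside on the unchanged `get_version` helper above: the `r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")` pattern simply captures the quoted value assigned to `VERSION` in `env.py`. A small self-contained sketch (the file content below is a made-up placeholder, not the real version string):

```python
import re

# Hypothetical stand-in for the contents of src/llamafactory/extras/env.py.
file_content = 'VERSION = "0.0.0.dev0"'

pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")  # expands to: VERSION\W*=\W*"([^"]+)"
(version,) = re.findall(pattern, file_content)
print(version)  # -> 0.0.0.dev0
```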
@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
 def main():
 chat_model = ChatModel()
 app = create_app(chat_model)
-api_host = os.environ.get("API_HOST", "0.0.0.0")
+api_host = os.getenv("API_HOST", "0.0.0.0")
-api_port = int(os.environ.get("API_PORT", "8000"))
+api_port = int(os.getenv("API_PORT", "8000"))
-print("Visit http://localhost:{}/docs for API document.".format(api_port))
+print(f"Visit http://localhost:{api_port}/docs for API document.")
 uvicorn.run(app, host=api_host, port=api_port)
@ -20,17 +20,17 @@ Level:

 Dependency graph:
 main:
-transformers>=4.41.2,<=4.45.0
+transformers>=4.41.2,<=4.46.1
-datasets>=2.16.0,<=2.21.0
+datasets>=2.16.0,<=3.0.2
-accelerate>=0.30.1,<=0.34.2
+accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
 attention:
 transformers>=4.42.4 (gemma+fa2)
 longlora:
-transformers>=4.41.2,<=4.45.0
+transformers>=4.41.2,<=4.46.1
 packing:
-transformers>=4.41.2,<=4.45.0
+transformers>=4.41.2,<=4.46.1

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1

@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
 """

 from .extras.env import VERSION
@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem


 def create_app(chat_model: "ChatModel") -> "FastAPI":
-root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+root_path = os.getenv("FASTAPI_ROOT_PATH", "")
 app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
 app.add_middleware(
 CORSMiddleware,

@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 allow_methods=["*"],
 allow_headers=["*"],
 )
-api_key = os.environ.get("API_KEY", None)
+api_key = os.getenv("API_KEY")
 security = HTTPBearer(auto_error=False)

 async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):

@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 dependencies=[Depends(verify_api_key)],
 )
 async def list_models():
-model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
+model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
 return ModelList(data=[model_card])

 @app.post(

@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 def run_api() -> None:
 chat_model = ChatModel()
 app = create_app(chat_model)
-api_host = os.environ.get("API_HOST", "0.0.0.0")
+api_host = os.getenv("API_HOST", "0.0.0.0")
-api_port = int(os.environ.get("API_PORT", "8000"))
+api_port = int(os.getenv("API_PORT", "8000"))
-print("Visit http://localhost:{}/docs for API document.".format(api_port))
+print(f"Visit http://localhost:{api_port}/docs for API document.")
 uvicorn.run(app, host=api_host, port=api_port)
@ -21,7 +21,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (

@ -57,7 +57,7 @@ if TYPE_CHECKING:
 from .protocol import ChatCompletionRequest, ScoreEvaluationRequest


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 ROLE_MAPPING = {
 Role.USER: DataRole.USER.value,
 Role.ASSISTANT: DataRole.ASSISTANT.value,

@ -69,8 +69,8 @@ ROLE_MAPPING = {

 def _process_request(
 request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
-logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

 if len(request.messages) == 0:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

@ -84,7 +84,7 @@ def _process_request(
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

 input_messages = []
-image = None
+images = []
 for i, message in enumerate(request.messages):
 if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

@ -111,7 +111,7 @@ def _process_request(
 else: # web uri
 image_stream = requests.get(image_url, stream=True).raw

-image = Image.open(image_stream).convert("RGB")
+images.append(Image.open(image_stream).convert("RGB"))
 else:
 input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

@ -124,7 +124,7 @@ def _process_request(
 else:
 tools = None

-return input_messages, system, tools, image
+return input_messages, system, tools, images or None


 def _create_stream_chat_completion_chunk(

@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
 request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-input_messages, system, tools, image = _process_request(request)
+input_messages, system, tools, images = _process_request(request)
 responses = await chat_model.achat(
 input_messages,
 system,
 tools,
-image,
+images,
 do_sample=request.do_sample,
 temperature=request.temperature,
 top_p=request.top_p,

@ -169,7 +169,7 @@ async def create_chat_completion_response(
 tool_calls = []
 for tool in result:
 function = Function(name=tool[0], arguments=tool[1])
-tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))

 response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
 finish_reason = Finish.TOOL

@ -193,8 +193,8 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
 request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
+completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-input_messages, system, tools, image = _process_request(request)
+input_messages, system, tools, images = _process_request(request)
 if tools:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
 input_messages,
 system,
 tools,
-image,
+images,
 do_sample=request.do_sample,
 temperature=request.temperature,
 top_p=request.top_p,

@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
 request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
-score_id = "scoreval-{}".format(uuid.uuid4().hex)
+score_id = f"scoreval-{uuid.uuid4().hex}"
 if len(request.messages) == 0:
 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
@ -66,8 +66,8 @@ class BaseEngine(ABC):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["ImageInput"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
-video: Optional["VideoInput"] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> List["Response"]:
 r"""

@ -81,8 +81,8 @@ class BaseEngine(ABC):
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["ImageInput"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
-video: Optional["VideoInput"] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> AsyncGenerator[str, None]:
 r"""
@ -53,7 +53,7 @@ class ChatModel:
 elif model_args.infer_backend == "vllm":
 self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
 else:
-raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

 self._loop = asyncio.new_event_loop()
 self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)

@ -64,15 +64,15 @@ class ChatModel:
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["ImageInput"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
-video: Optional["VideoInput"] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> List["Response"]:
 r"""
 Gets a list of responses of the chat model.
 """
 task = asyncio.run_coroutine_threadsafe(
-self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
 )
 return task.result()

@ -81,28 +81,28 @@ class ChatModel:
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["ImageInput"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
-video: Optional["VideoInput"] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> List["Response"]:
 r"""
 Asynchronously gets a list of responses of the chat model.
 """
-return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)

 def stream_chat(
 self,
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["ImageInput"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
-video: Optional["VideoInput"] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> Generator[str, None, None]:
 r"""
 Gets the response token-by-token of the chat model.
 """
-generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
 while True:
 try:
 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)

@ -115,14 +115,14 @@ class ChatModel:
 messages: Sequence[Dict[str, str]],
 system: Optional[str] = None,
 tools: Optional[str] = None,
-image: Optional["ImageInput"] = None,
+images: Optional[Sequence["ImageInput"]] = None,
-video: Optional["VideoInput"] = None,
+videos: Optional[Sequence["VideoInput"]] = None,
 **input_kwargs,
 ) -> AsyncGenerator[str, None]:
 r"""
 Asynchronously gets the response token-by-token of the chat model.
 """
-async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
 yield new_token

 def get_scores(
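To make the `image`/`video` to `images`/`videos` signature change above concrete, here is a hedged usage sketch of the multi-image call path. The model path, template name, image files, and constructor arguments are hypothetical placeholders, and the `Response.response_text` field is assumed from the existing API rather than confirmed by this diff:

```python
# Sketch only: calling the updated ChatModel API with a list of images.
from PIL import Image

from llamafactory.chat import ChatModel

# Hypothetical arguments; any multimodal model/template supported by LLaMA-Factory would do.
chat_model = ChatModel({"model_name_or_path": "path/to/your/mllm", "template": "qwen2_vl"})

messages = [{"role": "user", "content": "<image><image>Compare the two images."}]
images = [Image.open("first.jpg").convert("RGB"), Image.open("second.jpg").convert("RGB")]

for response in chat_model.chat(messages, images=images):
    print(response.response_text)
```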
@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.misc import get_logits_processor
|
from ..extras.misc import get_logits_processor
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .base_engine import BaseEngine, Response
|
from .base_engine import BaseEngine, Response
|
||||||
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceEngine(BaseEngine):
|
class HuggingfaceEngine(BaseEngine):
|
||||||
@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
try:
|
try:
|
||||||
asyncio.get_event_loop()
|
asyncio.get_event_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logger.warning("There is no current event loop, creating a new one.")
|
logger.warning_once("There is no current event loop, creating a new one.")
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||||
|
|
||||||
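The engine now reads `MAX_CONCURRENT` via `os.getenv` and uses it to size an `asyncio.Semaphore`, capping how many generation requests run at the same time. A rough sketch of that gating pattern in isolation, assuming `MAX_CONCURRENT=2`:

```python
# Minimal sketch of the concurrency gate used above; the sleep stands in for generation.
import asyncio
import os

semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

async def generate(request_id: int) -> str:
    async with semaphore:  # at most MAX_CONCURRENT coroutines enter this block at once
        await asyncio.sleep(0.1)  # stand-in for model.generate(...)
        return f"response for request {request_id}"

async def main() -> None:
    print(await asyncio.gather(*(generate(i) for i in range(4))))

asyncio.run(main())
```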
@staticmethod
|
@staticmethod
|
||||||
def _process_args(
|
def _process_args(
|
||||||
@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||||
if image is not None:
|
if images is not None:
|
||||||
mm_input_dict.update({"images": [image], "imglens": [1]})
|
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||||
|
|
||||||
if video is not None:
|
if videos is not None:
|
||||||
mm_input_dict.update({"videos": [video], "vidlens": [1]})
|
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
|
||||||
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
|
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
|
||||||
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
|
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
|
||||||
|
|
||||||
messages = template.mm_plugin.process_messages(
|
messages = template.mm_plugin.process_messages(
|
||||||
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
||||||
@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||||
|
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
|
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
|
||||||
|
|
||||||
generating_args = generating_args.copy()
|
generating_args = generating_args.copy()
|
||||||
generating_args.update(
|
generating_args.update(
|
||||||
@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
||||||
for key, value in mm_inputs.items():
|
for key, value in mm_inputs.items():
|
||||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
|
||||||
|
value = torch.stack(value) # assume they have same sizes
|
||||||
|
elif not isinstance(value, torch.Tensor):
|
||||||
|
value = torch.tensor(value)
|
||||||
|
|
||||||
gen_kwargs[key] = value.to(model.device)
|
gen_kwargs[key] = value.to(model.device)
|
||||||
|
|
||||||
return gen_kwargs, prompt_length
|
return gen_kwargs, prompt_length
|
||||||
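The loop above now distinguishes plain arrays from lists of tensors: Pixtral-style processors return `pixel_values` as a list of per-image tensors, which are stacked (assuming equal shapes, as the inline comment notes) before being moved to the model device. A small sketch of that branch on its own:

```python
# Sketch of the tensor normalization applied to each multimodal input above.
import torch

def to_device_tensor(value, device: torch.device) -> torch.Tensor:
    if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):
        value = torch.stack(value)  # assumes the per-image tensors share a shape
    elif not isinstance(value, torch.Tensor):
        value = torch.tensor(value)

    return value.to(device)

pixel_values = [torch.zeros(3, 16, 16), torch.zeros(3, 16, 16)]  # dummy pixtral-style list
print(to_device_tensor(pixel_values, torch.device("cpu")).shape)  # torch.Size([2, 3, 16, 16])
```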
@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
model,
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
|
template,
|
||||||
|
generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
images,
|
||||||
|
videos,
|
||||||
|
input_kwargs,
|
||||||
)
|
)
|
||||||
generate_output = model.generate(**gen_kwargs)
|
generate_output = model.generate(**gen_kwargs)
|
||||||
response_ids = generate_output[:, prompt_length:]
|
response_ids = generate_output[:, prompt_length:]
|
||||||
@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||||
) -> Callable[[], str]:
|
) -> Callable[[], str]:
|
||||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
model,
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
|
template,
|
||||||
|
generating_args,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
tools,
|
||||||
|
images,
|
||||||
|
videos,
|
||||||
|
input_kwargs,
|
||||||
)
|
)
|
||||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
gen_kwargs["streamer"] = streamer
|
gen_kwargs["streamer"] = streamer
|
||||||
@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
video,
|
videos,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
video,
|
videos,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
|
@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.misc import get_device_count
|
from ..extras.misc import get_device_count
|
||||||
from ..extras.packages import is_pillow_available, is_vllm_available
|
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||||
from ..model import load_config, load_tokenizer
|
from ..model import load_config, load_tokenizer
|
||||||
@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VllmEngine(BaseEngine):
|
class VllmEngine(BaseEngine):
|
||||||
@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
|
|||||||
if getattr(config, "is_yi_vl_derived_model", None):
|
if getattr(config, "is_yi_vl_derived_model", None):
|
||||||
import vllm.model_executor.models.llava
|
import vllm.model_executor.models.llava
|
||||||
|
|
||||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||||
|
|
||||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||||
@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncIterator["RequestOutput"]:
|
) -> AsyncIterator["RequestOutput"]:
|
||||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
if image is not None:
|
if images is not None:
|
||||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||||
|
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
system = system or self.generating_args["default_system"]
|
system = system or self.generating_args["default_system"]
|
||||||
@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
|
|||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if image is not None: # add image features
|
if images is not None: # add image features
|
||||||
if not isinstance(image, (str, ImageObject)):
|
image_data = []
|
||||||
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
|
for image in images:
|
||||||
|
if not isinstance(image, (str, ImageObject)):
|
||||||
|
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||||
|
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image).convert("RGB")
|
image = Image.open(image).convert("RGB")
|
||||||
|
|
||||||
multi_modal_data = {"image": image}
|
image_data.append(image)
|
||||||
|
|
||||||
|
multi_modal_data = {"image": image_data}
|
||||||
else:
|
else:
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
|
|
||||||
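For the vLLM path, the single-image branch becomes a loop that validates each entry, opens string paths with PIL, and hands vLLM a list under the `"image"` key of `multi_modal_data`. A sketch of that preprocessing step (the input values are assumptions):

```python
# Sketch of how the image list above is normalized before being passed to vLLM.
from PIL import Image
from PIL.Image import Image as ImageObject

def build_multi_modal_data(images):
    image_data = []
    for image in images:
        if not isinstance(image, (str, ImageObject)):
            raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

        if isinstance(image, str):
            image = Image.open(image).convert("RGB")  # load from an assumed local path

        image_data.append(image)

    return {"image": image_data}  # vLLM receives all images for the request at once
```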
@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
final_output = None
|
final_output = None
|
||||||
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||||
async for request_output in generator:
|
async for request_output in generator:
|
||||||
final_output = request_output
|
final_output = request_output
|
||||||
|
|
||||||
@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||||
async for result in generator:
|
async for result in generator:
|
||||||
delta_text = result.outputs[0].text[len(generated_text) :]
|
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||||
generated_text = result.outputs[0].text
|
generated_text = result.outputs[0].text
|
||||||
|
@ -22,8 +22,8 @@ from . import launcher
|
|||||||
from .api.app import run_api
|
from .api.app import run_api
|
||||||
from .chat.chat_model import run_chat
|
from .chat.chat_model import run_chat
|
||||||
from .eval.evaluator import run_eval
|
from .eval.evaluator import run_eval
|
||||||
|
from .extras import logging
|
||||||
from .extras.env import VERSION, print_env
|
from .extras.env import VERSION, print_env
|
||||||
from .extras.logging import get_logger
|
|
||||||
from .extras.misc import get_device_count
|
from .extras.misc import get_device_count
|
||||||
from .train.tuner import export_model, run_exp
|
from .train.tuner import export_model, run_exp
|
||||||
from .webui.interface import run_web_demo, run_web_ui
|
from .webui.interface import run_web_demo, run_web_ui
|
||||||
@ -47,7 +47,7 @@ USAGE = (
|
|||||||
WELCOME = (
|
WELCOME = (
|
||||||
"-" * 58
|
"-" * 58
|
||||||
+ "\n"
|
+ "\n"
|
||||||
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
|
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||||
+ " " * (21 - len(VERSION))
|
+ " " * (21 - len(VERSION))
|
||||||
+ "|\n|"
|
+ "|\n|"
|
||||||
+ " " * 56
|
+ " " * 56
|
||||||
@ -56,7 +56,7 @@ WELCOME = (
|
|||||||
+ "-" * 58
|
+ "-" * 58
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@unique
|
@unique
|
||||||
@ -86,19 +86,19 @@ def main():
|
|||||||
elif command == Command.EXPORT:
|
elif command == Command.EXPORT:
|
||||||
export_model()
|
export_model()
|
||||||
elif command == Command.TRAIN:
|
elif command == Command.TRAIN:
|
||||||
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||||
if force_torchrun or get_device_count() > 1:
|
if force_torchrun or get_device_count() > 1:
|
||||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
|
||||||
process = subprocess.run(
|
process = subprocess.run(
|
||||||
(
|
(
|
||||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||||
).format(
|
).format(
|
||||||
nnodes=os.environ.get("NNODES", "1"),
|
nnodes=os.getenv("NNODES", "1"),
|
||||||
node_rank=os.environ.get("RANK", "0"),
|
node_rank=os.getenv("NODE_RANK", "0"),
|
||||||
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
file_name=launcher.__file__,
|
file_name=launcher.__file__,
|
||||||
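The CLI now resolves every distributed setting through `os.getenv`, including the renamed `NODE_RANK` variable, before formatting the `torchrun` command. A sketch of the command string that results under assumed environment values (the file name and forwarded arguments are placeholders):

```python
# Sketch of the torchrun command assembled above, with assumed environment values.
import os

os.environ.update({"NNODES": "1", "NODE_RANK": "0", "NPROC_PER_NODE": "8",
                   "MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500"})

command = (
    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
    nnodes=os.getenv("NNODES", "1"),
    node_rank=os.getenv("NODE_RANK", "0"),
    nproc_per_node=os.getenv("NPROC_PER_NODE", "8"),
    master_addr=os.getenv("MASTER_ADDR", "127.0.0.1"),
    master_port=os.getenv("MASTER_PORT", "29500"),
    file_name="launcher.py",          # placeholder for launcher.__file__
    args="examples/train_lora.yaml",  # placeholder for the forwarded CLI arguments
)
print(command)
```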
@ -118,4 +118,4 @@ def main():
|
|||||||
elif command == Command.HELP:
|
elif command == Command.HELP:
|
||||||
print(USAGE)
|
print(USAGE)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown command: {}.".format(command))
|
raise NotImplementedError(f"Unknown command: {command}.")
|
||||||
|
@ -16,7 +16,7 @@ import os
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
from .data_utils import Role
|
from .data_utils import Role
|
||||||
|
|
||||||
|
|
||||||
@ -29,45 +29,51 @@ if TYPE_CHECKING:
|
|||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _convert_images(
|
def _convert_images(
|
||||||
images: Sequence["ImageInput"],
|
images: Union["ImageInput", Sequence["ImageInput"]],
|
||||||
dataset_attr: "DatasetAttr",
|
dataset_attr: "DatasetAttr",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Optional[List["ImageInput"]]:
|
) -> Optional[List["ImageInput"]]:
|
||||||
r"""
|
r"""
|
||||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||||
"""
|
"""
|
||||||
if len(images) == 0:
|
if not isinstance(images, list):
|
||||||
|
images = [images]
|
||||||
|
elif len(images) == 0:
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
|
images = images[:]
|
||||||
|
|
||||||
images = images[:]
|
|
||||||
if dataset_attr.load_from in ["script", "file"]:
|
if dataset_attr.load_from in ["script", "file"]:
|
||||||
for i in range(len(images)):
|
for i in range(len(images)):
|
||||||
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
|
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
|
||||||
images[i] = os.path.join(data_args.dataset_dir, images[i])
|
images[i] = os.path.join(data_args.image_dir, images[i])
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
def _convert_videos(
|
def _convert_videos(
|
||||||
videos: Sequence["VideoInput"],
|
videos: Union["VideoInput", Sequence["VideoInput"]],
|
||||||
dataset_attr: "DatasetAttr",
|
dataset_attr: "DatasetAttr",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Optional[List["VideoInput"]]:
|
) -> Optional[List["VideoInput"]]:
|
||||||
r"""
|
r"""
|
||||||
Optionally concatenates video path to dataset dir when loading from local disk.
|
Optionally concatenates video path to dataset dir when loading from local disk.
|
||||||
"""
|
"""
|
||||||
if len(videos) == 0:
|
if not isinstance(videos, list):
|
||||||
|
videos = [videos]
|
||||||
|
elif len(videos) == 0:
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
|
videos = videos[:]
|
||||||
|
|
||||||
videos = videos[:]
|
|
||||||
if dataset_attr.load_from in ["script", "file"]:
|
if dataset_attr.load_from in ["script", "file"]:
|
||||||
for i in range(len(videos)):
|
for i in range(len(videos)):
|
||||||
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
|
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
|
||||||
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
|
videos[i] = os.path.join(data_args.image_dir, videos[i])
|
||||||
|
|
||||||
return videos
|
return videos
|
||||||
|
|
||||||
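`_convert_images` and `_convert_videos` now accept either a single item or a sequence, and join relative paths against `data_args.image_dir` instead of `dataset_dir` when the dataset is loaded from a script or local file. A reduced sketch of the normalization (the directory value is an assumption):

```python
# Sketch of the single-or-list normalization and path joining used above.
import os
from typing import List, Optional, Sequence, Union

def convert_images(images: Union[str, Sequence[str]], image_dir: str = "data") -> Optional[List[str]]:
    if not isinstance(images, list):
        images = [images]   # a lone path becomes a one-element list
    elif len(images) == 0:
        return None         # keep empty lists as "no images"
    else:
        images = images[:]  # copy before mutating

    for i in range(len(images)):
        if isinstance(images[i], str) and os.path.isfile(os.path.join(image_dir, images[i])):
            images[i] = os.path.join(image_dir, images[i])

    return images
```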
@ -161,7 +167,7 @@ def convert_sharegpt(
|
|||||||
broken_data = False
|
broken_data = False
|
||||||
for turn_idx, message in enumerate(messages):
|
for turn_idx, message in enumerate(messages):
|
||||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||||
logger.warning("Invalid role tag in {}.".format(messages))
|
logger.warning_rank0(f"Invalid role tag in {messages}.")
|
||||||
broken_data = True
|
broken_data = True
|
||||||
|
|
||||||
aligned_messages.append(
|
aligned_messages.append(
|
||||||
@ -171,7 +177,7 @@ def convert_sharegpt(
|
|||||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||||
):
|
):
|
||||||
logger.warning("Invalid message count in {}.".format(messages))
|
logger.warning_rank0(f"Invalid message count in {messages}.")
|
||||||
broken_data = True
|
broken_data = True
|
||||||
|
|
||||||
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||||
@ -192,7 +198,7 @@ def convert_sharegpt(
|
|||||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||||
):
|
):
|
||||||
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
|
||||||
broken_data = True
|
broken_data = True
|
||||||
|
|
||||||
prompt = aligned_messages
|
prompt = aligned_messages
|
||||||
@ -205,7 +211,7 @@ def convert_sharegpt(
|
|||||||
response = aligned_messages[-1:]
|
response = aligned_messages[-1:]
|
||||||
|
|
||||||
if broken_data:
|
if broken_data:
|
||||||
logger.warning("Skipping this abnormal example.")
|
logger.warning_rank0("Skipping this abnormal example.")
|
||||||
prompt, response = [], []
|
prompt, response = [], []
|
||||||
|
|
||||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
|
@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||||
features.update(mm_inputs)
|
features.update(mm_inputs)
|
||||||
|
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
|
||||||
|
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
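When the processor returns `pixel_values` as a list of tensors (the Pixtral case), forcing the batch through `BatchEncoding.to()` would fail, so the collator above unwraps `features.data` and lets the default collation handle the list. A short sketch of that check, assuming the feature dict shape described in the diff:

```python
# Sketch of the post-collation adjustment above: fall back to a plain dict when
# pixel_values is a list of tensors (Pixtral-style inputs).
from transformers import BatchEncoding

def finalize(features: BatchEncoding):
    if isinstance(features.get("pixel_values"), list):
        features = features.data  # plain dict; list-valued entries are left untouched
    return features
```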
@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
|||||||
for key in ("chosen", "rejected"):
|
for key in ("chosen", "rejected"):
|
||||||
for feature in features:
|
for feature in features:
|
||||||
target_feature = {
|
target_feature = {
|
||||||
"input_ids": feature["{}_input_ids".format(key)],
|
"input_ids": feature[f"{key}_input_ids"],
|
||||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
"attention_mask": feature[f"{key}_attention_mask"],
|
||||||
"labels": feature["{}_labels".format(key)],
|
"labels": feature[f"{key}_labels"],
|
||||||
"images": feature["images"],
|
"images": feature["images"],
|
||||||
"videos": feature["videos"],
|
"videos": feature["videos"],
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
|
|||||||
|
|
||||||
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||||
@ -56,12 +56,12 @@ def merge_dataset(
|
|||||||
return all_datasets[0]
|
return all_datasets[0]
|
||||||
elif data_args.mix_strategy == "concat":
|
elif data_args.mix_strategy == "concat":
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
|
||||||
|
|
||||||
return concatenate_datasets(all_datasets)
|
return concatenate_datasets(all_datasets)
|
||||||
elif data_args.mix_strategy.startswith("interleave"):
|
elif data_args.mix_strategy.startswith("interleave"):
|
||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||||
|
|
||||||
return interleave_datasets(
|
return interleave_datasets(
|
||||||
datasets=all_datasets,
|
datasets=all_datasets,
|
||||||
@ -70,7 +70,7 @@ def merge_dataset(
|
|||||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
|
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def split_dataset(
|
||||||
|
@ -83,14 +83,14 @@ class StringFormatter(Formatter):
|
|||||||
if isinstance(slot, str):
|
if isinstance(slot, str):
|
||||||
for name, value in kwargs.items():
|
for name, value in kwargs.items():
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise RuntimeError("Expected a string, got {}".format(value))
|
raise RuntimeError(f"Expected a string, got {value}")
|
||||||
|
|
||||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
elif isinstance(slot, (dict, set)):
|
elif isinstance(slot, (dict, set)):
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
|
|||||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
functions = []
|
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
for name, arguments in functions:
|
for name, arguments in functions:
|
||||||
@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
|
|||||||
elif isinstance(slot, (dict, set)):
|
elif isinstance(slot, (dict, set)):
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
|
|||||||
tools = json.loads(content)
|
tools = json.loads(content)
|
||||||
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return [""]
|
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
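Both `FunctionFormatter` and `ToolFormatter` now raise a `RuntimeError` on malformed JSON instead of silently returning an empty result, so broken tool-call data surfaces early. The function message is expected to decode to an object (or list of objects) carrying `name` and `arguments` fields; a sketch of that parsing contract, with an illustrative payload:

```python
# Sketch of the stricter parsing behaviour: invalid JSON now raises instead of being swallowed.
# The example payload is an assumption for illustration only.
import json

def parse_function_message(content: str):
    try:
        tool_calls = json.loads(content)
        if not isinstance(tool_calls, list):
            tool_calls = [tool_calls]

        return [(call["name"], json.dumps(call["arguments"], ensure_ascii=False)) for call in tool_calls]
    except json.JSONDecodeError:
        raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")

print(parse_function_message('{"name": "get_weather", "arguments": {"city": "Beijing"}}'))
```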
|
@ -20,8 +20,8 @@ import numpy as np
|
|||||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import FILEEXT2TYPE
|
from ..extras.constants import FILEEXT2TYPE
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.misc import has_tokenized_data
|
from ..extras.misc import has_tokenized_data
|
||||||
from .aligner import align_dataset
|
from .aligner import align_dataset
|
||||||
from .data_utils import merge_dataset, split_dataset
|
from .data_utils import merge_dataset, split_dataset
|
||||||
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
|||||||
from .template import Template
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _load_single_dataset(
|
def _load_single_dataset(
|
||||||
@ -51,9 +51,9 @@ def _load_single_dataset(
|
|||||||
r"""
|
r"""
|
||||||
Loads a single dataset and aligns it to the standard format.
|
Loads a single dataset and aligns it to the standard format.
|
||||||
"""
|
"""
|
||||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
logger.info_rank0(f"Loading dataset {dataset_attr}...")
|
||||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
|
||||||
data_path = dataset_attr.dataset_name
|
data_path = dataset_attr.dataset_name
|
||||||
data_name = dataset_attr.subset
|
data_name = dataset_attr.subset
|
||||||
data_dir = dataset_attr.folder
|
data_dir = dataset_attr.folder
|
||||||
@ -69,25 +69,24 @@ def _load_single_dataset(
|
|||||||
if os.path.isdir(local_path): # is directory
|
if os.path.isdir(local_path): # is directory
|
||||||
for file_name in os.listdir(local_path):
|
for file_name in os.listdir(local_path):
|
||||||
data_files.append(os.path.join(local_path, file_name))
|
data_files.append(os.path.join(local_path, file_name))
|
||||||
if data_path is None:
|
|
||||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
|
||||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
|
||||||
raise ValueError("File types should be identical.")
|
|
||||||
elif os.path.isfile(local_path): # is file
|
elif os.path.isfile(local_path): # is file
|
||||||
data_files.append(local_path)
|
data_files.append(local_path)
|
||||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("File {} not found.".format(local_path))
|
raise ValueError(f"File {local_path} not found.")
|
||||||
|
|
||||||
|
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||||
|
|
||||||
|
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
|
||||||
|
raise ValueError("File types should be identical.")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||||
|
|
||||||
if dataset_attr.load_from == "ms_hub":
|
if dataset_attr.load_from == "ms_hub":
|
||||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||||
from modelscope import MsDataset
|
from modelscope import MsDataset # type: ignore
|
||||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||||
dataset = MsDataset.load(
|
dataset = MsDataset.load(
|
||||||
@ -98,10 +97,27 @@ def _load_single_dataset(
|
|||||||
split=dataset_attr.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.ms_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
use_streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
if isinstance(dataset, MsDataset):
|
if isinstance(dataset, MsDataset):
|
||||||
dataset = dataset.to_hf_dataset()
|
dataset = dataset.to_hf_dataset()
|
||||||
|
|
||||||
|
elif dataset_attr.load_from == "om_hub":
|
||||||
|
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||||
|
from openmind import OmDataset # type: ignore
|
||||||
|
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
|
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
|
||||||
|
dataset = OmDataset.load_dataset(
|
||||||
|
path=data_path,
|
||||||
|
name=data_name,
|
||||||
|
data_dir=data_dir,
|
||||||
|
data_files=data_files,
|
||||||
|
split=dataset_attr.split,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
token=model_args.om_hub_token,
|
||||||
|
streaming=data_args.streaming,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path=data_path,
|
path=data_path,
|
||||||
@ -111,13 +127,10 @@ def _load_single_dataset(
|
|||||||
split=dataset_attr.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
streaming=data_args.streaming,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
|
||||||
|
|
||||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||||
target_num = dataset_attr.num_samples
|
target_num = dataset_attr.num_samples
|
||||||
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||||
@ -128,7 +141,7 @@ def _load_single_dataset(
|
|||||||
|
|
||||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||||
dataset = dataset.select(indexes)
|
dataset = dataset.select(indexes)
|
||||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
|
||||||
|
|
||||||
if data_args.max_samples is not None: # truncate dataset
|
if data_args.max_samples is not None: # truncate dataset
|
||||||
max_samples = min(data_args.max_samples, len(dataset))
|
max_samples = min(data_args.max_samples, len(dataset))
|
||||||
@ -224,9 +237,9 @@ def get_dataset(
|
|||||||
# Load tokenized dataset
|
# Load tokenized dataset
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if has_tokenized_data(data_args.tokenized_path):
|
if has_tokenized_data(data_args.tokenized_path):
|
||||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
|
||||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
|
||||||
|
|
||||||
dataset_module: Dict[str, "Dataset"] = {}
|
dataset_module: Dict[str, "Dataset"] = {}
|
||||||
if "train" in dataset_dict:
|
if "train" in dataset_dict:
|
||||||
@ -277,8 +290,8 @@ def get_dataset(
|
|||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if training_args.should_save:
|
if training_args.should_save:
|
||||||
dataset_dict.save_to_disk(data_args.tokenized_path)
|
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
|
||||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
|
||||||
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
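In the `_load_single_dataset` hunks above, the local-file branch now collects all files first, infers the dataset type once from the first file's extension, and then checks that every file maps to the same type. A sketch of that inference (the `FILEEXT2TYPE` mapping shown here is a trimmed assumption):

```python
# Sketch of the refactored file-type inference above.
import os

FILEEXT2TYPE = {"json": "json", "jsonl": "json", "csv": "csv", "parquet": "parquet", "txt": "text"}  # trimmed

def infer_data_path(data_files):
    data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
    if data_path is None:
        raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

    if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
        raise ValueError("File types should be identical.")

    return data_path

print(infer_data_path(["alpaca_en.json", "alpaca_zh.jsonl"]))  # -> "json"
```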
|
@ -4,6 +4,7 @@ from io import BytesIO
|
|||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from transformers.image_utils import get_image_size, to_numpy_array
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
@ -110,7 +111,7 @@ class BasePlugin:
|
|||||||
image = Image.open(image["path"])
|
image = Image.open(image["path"])
|
||||||
|
|
||||||
if not isinstance(image, ImageObject):
|
if not isinstance(image, ImageObject):
|
||||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
|
||||||
|
|
||||||
results.append(self._preprocess_image(image, **kwargs))
|
results.append(self._preprocess_image(image, **kwargs))
|
||||||
|
|
||||||
@ -157,6 +158,7 @@ class BasePlugin:
|
|||||||
It holds num_patches == torch.prod(image_grid_thw)
|
It holds num_patches == torch.prod(image_grid_thw)
|
||||||
"""
|
"""
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
|
||||||
input_dict = {"images": None} # default key
|
input_dict = {"images": None} # default key
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
@ -174,10 +176,16 @@ class BasePlugin:
|
|||||||
)
|
)
|
||||||
input_dict["videos"] = videos
|
input_dict["videos"] = videos
|
||||||
|
|
||||||
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
|
mm_inputs = {}
|
||||||
return image_processor(**input_dict, return_tensors="pt")
|
if image_processor != video_processor:
|
||||||
else:
|
if input_dict.get("images") is not None:
|
||||||
return {}
|
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
|
||||||
|
if input_dict.get("videos") is not None:
|
||||||
|
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
|
||||||
|
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
|
||||||
|
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
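`_get_mm_inputs` now pulls a dedicated `video_processor` from the processor when one exists (falling back to the image processor otherwise) and routes images and videos through the matching component. A condensed sketch of the dispatch shown above:

```python
# Condensed sketch of the image/video processor dispatch above.
def get_mm_inputs(processor, images=None, videos=None):
    image_processor = getattr(processor, "image_processor")
    video_processor = getattr(processor, "video_processor", image_processor)

    mm_inputs = {}
    if image_processor != video_processor:  # separate processors (e.g. LLaVA-NeXT-Video)
        if images is not None:
            mm_inputs.update(image_processor(images, return_tensors="pt"))
        if videos is not None:
            mm_inputs.update(video_processor(videos, return_tensors="pt"))
    elif images is not None or videos is not None:  # same processor (qwen2-vl)
        mm_inputs.update(image_processor(images=images, videos=videos, return_tensors="pt"))

    return mm_inputs
```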
def process_messages(
|
def process_messages(
|
||||||
self,
|
self,
|
||||||
@ -218,6 +226,14 @@ class BasePlugin:
|
|||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
r"""
|
r"""
|
||||||
Builds batched multimodal inputs for VLMs.
|
Builds batched multimodal inputs for VLMs.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
images: a list of image inputs, shape (num_images,)
|
||||||
|
videos: a list of video inputs, shape (num_videos,)
|
||||||
|
imglens: number of images in each sample, shape (batch_size,)
|
||||||
|
vidlens: number of videos in each sample, shape (batch_size,)
|
||||||
|
seqlens: number of tokens in each sample, shape (batch_size,)
|
||||||
|
processor: a processor for pre-processing images and videos
|
||||||
"""
|
"""
|
||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
return {}
|
return {}
|
||||||
@ -245,7 +261,123 @@ class LlavaPlugin(BasePlugin):
|
|||||||
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
|
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
if len(images) != num_image_tokens:
|
||||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
seqlens: Sequence[int],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
num_image_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
if "image_sizes" in mm_inputs:
|
||||||
|
image_sizes = iter(mm_inputs["image_sizes"])
|
||||||
|
if "pixel_values" in mm_inputs:
|
||||||
|
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while self.image_token in content:
|
||||||
|
image_size = next(image_sizes)
|
||||||
|
orig_height, orig_width = image_size
|
||||||
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
|
if processor.vision_feature_select_strategy == "default":
|
||||||
|
image_seqlen -= 1
|
||||||
|
num_image_tokens += 1
|
||||||
|
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
|
||||||
|
|
||||||
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
seqlens: Sequence[int],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
res = self._get_mm_inputs(images, videos, processor)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextVideoPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
num_image_tokens = 0
|
||||||
|
num_video_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
if "pixel_values" in mm_inputs:
|
||||||
|
image_sizes = iter(mm_inputs["image_sizes"])
|
||||||
|
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
|
||||||
|
while self.image_token in content:
|
||||||
|
image_size = next(image_sizes)
|
||||||
|
orig_height, orig_width = image_size
|
||||||
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
|
if processor.vision_feature_select_strategy == "default":
|
||||||
|
image_seqlen -= 1
|
||||||
|
num_image_tokens += 1
|
||||||
|
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
|
||||||
|
|
||||||
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
|
if "pixel_values_videos" in mm_inputs:
|
||||||
|
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||||
|
height, width = get_image_size(pixel_values_video[0])
|
||||||
|
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||||
|
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
||||||
|
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while self.video_token in content:
|
||||||
|
num_video_tokens += 1
|
||||||
|
content = content.replace(self.video_token, "{{video}}", 1)
|
||||||
|
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
if len(videos) != num_video_tokens:
|
||||||
|
raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -284,7 +416,7 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
message["content"] = content.replace("{{image}}", "")
|
message["content"] = content.replace("{{image}}", "")
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
if len(images) != num_image_tokens:
|
||||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -324,6 +456,68 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
patch_size = getattr(processor, "patch_size")
|
||||||
|
image_token = getattr(processor, "image_token")
|
||||||
|
image_break_token = getattr(processor, "image_break_token")
|
||||||
|
image_end_token = getattr(processor, "image_end_token")
|
||||||
|
|
||||||
|
num_image_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
image_input_sizes = mm_inputs.get("image_sizes", None)
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
if image_input_sizes is None:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
image_size = image_input_sizes[0][num_image_tokens]
|
||||||
|
height, width = image_size
|
||||||
|
num_height_tokens = height // patch_size
|
||||||
|
num_width_tokens = width // patch_size
|
||||||
|
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||||
|
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
||||||
|
replace_tokens[-1] = image_end_token
|
||||||
|
replace_str = "".join(replace_tokens)
|
||||||
|
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
|
message["content"] = content
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
seqlens: Sequence[int],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
if mm_inputs.get("pixel_values"):
|
||||||
|
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
|
||||||
|
|
||||||
|
mm_inputs.pop("image_sizes", None)
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
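The new `PixtralPlugin` above expands each image placeholder into a token grid: `height // patch_size` rows of `width // patch_size` image tokens, each row terminated by a break token, with the final token replaced by an end token. A small sketch of that expansion (the token strings and patch size are assumptions based on the Pixtral processor attributes referenced in the diff):

```python
# Sketch of the Pixtral placeholder expansion above, with assumed token strings.
def expand_pixtral_image(height: int, width: int, patch_size: int = 16,
                         image_token: str = "[IMG]",
                         image_break_token: str = "[IMG_BREAK]",
                         image_end_token: str = "[IMG_END]") -> str:
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size
    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten rows
    replace_tokens[-1] = image_end_token  # the last row ends with the end token, not a break
    return "".join(replace_tokens)

print(expand_pixtral_image(32, 48))  # 2 rows of 3 [IMG] tokens, rows separated by [IMG_BREAK]
```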
class Qwen2vlPlugin(BasePlugin):
|
class Qwen2vlPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||||
@ -369,7 +563,7 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(image_grid_thw):
|
if num_image_tokens >= len(image_grid_thw):
|
||||||
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER,
|
||||||
@ -382,7 +576,7 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(video_grid_thw):
|
if num_video_tokens >= len(video_grid_thw):
|
||||||
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
|
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
VIDEO_PLACEHOLDER,
|
VIDEO_PLACEHOLDER,
|
||||||
@@ -396,10 +590,73 @@ class Qwen2vlPlugin(BasePlugin):
             message["content"] = content

         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")

         if len(videos) != num_video_tokens:
-            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")

+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)
+
+
+class VideoLlavaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        num_frames = 0
+        exist_images = "pixel_values_images" in mm_inputs
+        exist_videos = "pixel_values_videos" in mm_inputs
+        if exist_videos or exist_images:
+            if exist_images:
+                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                num_frames = 1
+            if exist_videos:
+                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(pixel_values_video[0])
+                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+            video_seqlen = image_seqlen * num_frames
+            if processor.vision_feature_select_strategy == "default":
+                image_seqlen -= 1
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+
+                content = content.replace("{{image}}", self.image_token * image_seqlen)
+                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
+
         return messages

@@ -420,8 +677,12 @@ class Qwen2vlPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
+    "llava_next": LlavaNextPlugin,
+    "llava_next_video": LlavaNextVideoPlugin,
     "paligemma": PaliGemmaPlugin,
+    "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
+    "video_llava": VideoLlavaPlugin,
 }


@@ -432,6 +693,6 @@ def get_mm_plugin(
 ) -> "BasePlugin":
     plugin_class = PLUGINS.get(name, None)
     if plugin_class is None:
-        raise ValueError("Multimodal plugin `{}` not found.".format(name))
+        raise ValueError(f"Multimodal plugin `{name}` not found.")

     return plugin_class(image_token, video_token)
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
 from transformers.utils import cached_file

 from ..extras.constants import DATA_CONFIG
-from ..extras.misc import use_modelscope
+from ..extras.misc import use_modelscope, use_openmind


 @dataclass
@@ -30,7 +30,7 @@ class DatasetAttr:
     """

     # basic configs
-    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
     dataset_name: str
     formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     ranking: bool = False
@@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
     config_path = os.path.join(dataset_dir, DATA_CONFIG)

     try:
-        with open(config_path, "r") as f:
+        with open(config_path) as f:
             dataset_info = json.load(f)
     except Exception as err:
         if len(dataset_names) != 0:
-            raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
+            raise ValueError(f"Cannot open {config_path} due to {str(err)}.")

         dataset_info = None

     dataset_list: List["DatasetAttr"] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            load_from = "ms_hub" if use_modelscope() else "hf_hub"
+            if use_modelscope():
+                load_from = "ms_hub"
+            elif use_openmind():
+                load_from = "om_hub"
+            else:
+                load_from = "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue

         if name not in dataset_info:
-            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+            raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")

         has_hf_url = "hf_hub_url" in dataset_info[name]
         has_ms_url = "ms_hub_url" in dataset_info[name]
+        has_om_url = "om_hub_url" in dataset_info[name]

-        if has_hf_url or has_ms_url:
-            if (use_modelscope() and has_ms_url) or (not has_hf_url):
+        if has_hf_url or has_ms_url or has_om_url:
+            if has_ms_url and (use_modelscope() or not has_hf_url):
                 dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+            elif has_om_url and (use_openmind() or not has_hf_url):
+                dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
             else:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
         elif "script_url" in dataset_info[name]:
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_feedback_example(
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
     if desirable_num == 0 or undesirable_num == 0:
-        logger.warning("Your dataset only has one preference type.")
+        logger.warning_rank0("Your dataset only has one preference type.")

     return model_inputs
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_pairwise_example(
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
@@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
     print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
     print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
     print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
-    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
     print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
     print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
     print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
-    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
+    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_supervised_example(
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
     length2indexes = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
@@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
-    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
@@ -15,7 +15,7 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

-from ...extras.logging import get_logger
+from ...extras import logging
 from ..data_utils import Role
 from .processor_utils import infer_seqlen

@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_unsupervised_example(
@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_unsupervised_example(
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 from transformers.utils.versions import require_version
 from typing_extensions import override

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
     from .mm_plugin import BasePlugin


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @dataclass
@@ -49,6 +49,7 @@ class Template:
     stop_words: List[str]
     efficient_eos: bool
     replace_eos: bool
+    replace_jinja_template: bool
     mm_plugin: "BasePlugin"

     def encode_oneturn(
@@ -146,7 +147,7 @@ class Template:
                 elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                     token_ids += [tokenizer.eos_token_id]
             else:
-                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

         return token_ids

@@ -214,6 +215,7 @@ def _register_template(
     stop_words: Sequence[str] = [],
     efficient_eos: bool = False,
     replace_eos: bool = False,
+    replace_jinja_template: bool = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 ) -> None:
     r"""
@@ -263,6 +265,7 @@ def _register_template(
         stop_words=stop_words,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
+        replace_jinja_template=replace_jinja_template,
         mm_plugin=mm_plugin,
     )

@@ -272,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
     num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

     if is_added:
-        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
     else:
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")

     if num_added_tokens > 0:
-        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")


 def _jinja_escape(content: str) -> str:
@@ -353,24 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
     """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version(
-            "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
-        )
-        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
         template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(data_args.template))
+            raise ValueError(f"Template {data_args.template} does not exist.")

+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

     if data_args.tool_format is not None:
-        logger.info("Using tool format: {}.".format(data_args.tool_format))
+        logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
         template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@@ -388,20 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:

     if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
-        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

     if stop_words:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
-        logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
         if num_added_tokens > 0:
-            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

-    try:
-        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-    except ValueError:
-        logger.info("Cannot add this chat template to tokenizer.")
+    if tokenizer.chat_template is None or template.replace_jinja_template:
+        try:
+            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
+        except ValueError as e:
+            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

     return template

@@ -640,6 +641,14 @@ _register_template(
 )


+_register_template(
+    name="exaone",
+    format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
+    format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+)
+
+
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -664,6 +673,7 @@ _register_template(
     format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     efficient_eos=True,
+    replace_jinja_template=False,
 )


@@ -681,6 +691,14 @@ _register_template(
 )


+_register_template(
+    name="index",
+    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
+    format_system=StringFormatter(slots=["<unk>{{content}}"]),
+    efficient_eos=True,
+)
+
+
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@@ -740,6 +758,7 @@ _register_template(
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
     replace_eos=True,
+    replace_jinja_template=False,
 )


@@ -754,6 +773,107 @@ _register_template(
 )


+_register_template(
+    name="llava_next",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_llama3",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_mistral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_qwen",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_yi",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_video",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
+_register_template(
+    name="llava_next_video_mistral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
+_register_template(
+    name="llava_next_video_yi",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -831,6 +951,14 @@ _register_template(
     replace_eos=True,
 )

+_register_template(
+    name="pixtral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
+
+
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -840,6 +968,7 @@ _register_template(
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    replace_jinja_template=False,
 )


@@ -852,6 +981,7 @@ _register_template(
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    replace_jinja_template=False,
     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )

@@ -907,6 +1037,17 @@ _register_template(
 )


+_register_template(
+    name="video_llava",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
+)
+
+
 _register_template(
     name="xuanyuan",
     format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@@ -177,6 +177,6 @@ TOOLS = {
 def get_tool_utils(name: str) -> "ToolUtils":
     tool_utils = TOOLS.get(name, None)
     if tool_utils is None:
-        raise ValueError("Tool utils `{}` not found.".format(name))
+        raise ValueError(f"Tool utils `{name}` not found.")

     return tool_utils
@@ -87,7 +87,7 @@ class Evaluator:
             token=self.model_args.hf_hub_token,
         )

-        with open(mapping, "r", encoding="utf-8") as f:
+        with open(mapping, encoding="utf-8") as f:
             categorys: Dict[str, Dict[str, str]] = json.load(f)

         category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@@ -139,7 +139,7 @@ class Evaluator:
     def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
-                "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+                f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
                 for category_name, category_correct in category_corrects.items()
                 if len(category_correct)
             ]
@@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->

 def get_eval_template(name: str) -> "EvalTemplate":
     eval_template = eval_templates.get(name, None)
-    assert eval_template is not None, "Template {} does not exist.".format(name)
+    assert eval_template is not None, f"Template {name} does not exist."
     return eval_template


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional
@@ -47,7 +48,7 @@ FILEEXT2TYPE = {

 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = "<image>"
+IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}

@@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {

 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

-VIDEO_PLACEHOLDER = "<video>"
+VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"

@@ -107,6 +108,7 @@ VISION_MODELS = set()
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
+    OPENMIND = "om"


 def register_model_group(
@@ -114,17 +116,12 @@ def register_model_group(
     template: Optional[str] = None,
     vision: bool = False,
 ) -> None:
-    prefix = None
     for name, path in models.items():
-        if prefix is None:
-            prefix = name.split("-")[0]
-        else:
-            assert prefix == name.split("-")[0], "prefix should be identical."
         SUPPORTED_MODELS[name] = path
-        if template is not None:
-            DEFAULT_TEMPLATE[prefix] = template
+        if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
+            DEFAULT_TEMPLATE[name] = template
         if vision:
-            VISION_MODELS.add(prefix)
+            VISION_MODELS.add(name)


 register_model_group(
@@ -168,14 +165,17 @@ register_model_group(
         "Baichuan2-13B-Base": {
             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
+            DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
         },
         "Baichuan2-7B-Chat": {
             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
+            DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
         },
         "Baichuan2-13B-Chat": {
             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
+            DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
         },
     },
     template="baichuan2",
@@ -274,27 +274,27 @@ register_model_group(

 register_model_group(
     models={
-        "ChineseLLaMA2-1.3B": {
+        "Chinese-Llama-2-1.3B": {
             DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
         },
-        "ChineseLLaMA2-7B": {
+        "Chinese-Llama-2-7B": {
             DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
         },
-        "ChineseLLaMA2-13B": {
+        "Chinese-Llama-2-13B": {
             DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
         },
-        "ChineseLLaMA2-1.3B-Chat": {
+        "Chinese-Alpaca-2-1.3B-Chat": {
             DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
         },
-        "ChineseLLaMA2-7B-Chat": {
+        "Chinese-Alpaca-2-7B-Chat": {
             DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
         },
-        "ChineseLLaMA2-13B-Chat": {
+        "Chinese-Alpaca-2-13B-Chat": {
             DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
         },
@@ -450,25 +450,25 @@ register_model_group(

 register_model_group(
     models={
-        "DeepSeekCoder-6.7B-Base": {
+        "DeepSeek-Coder-6.7B-Base": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
         },
-        "DeepSeekCoder-7B-Base": {
+        "DeepSeek-Coder-7B-Base": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
         },
-        "DeepSeekCoder-33B-Base": {
+        "DeepSeek-Coder-33B-Base": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
         },
-        "DeepSeekCoder-6.7B-Instruct": {
+        "DeepSeek-Coder-6.7B-Instruct": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
         },
-        "DeepSeekCoder-7B-Instruct": {
+        "DeepSeek-Coder-7B-Instruct": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
         },
-        "DeepSeekCoder-33B-Instruct": {
+        "DeepSeek-Coder-33B-Instruct": {
             DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
             DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
         },
@@ -477,6 +477,16 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "EXAONE-3.0-7.8B-Instruct": {
+            DownloadSource.DEFAULT: "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
+        },
+    },
+    template="exaone",
+)
+
+
 register_model_group(
     models={
         "Falcon-7B": {
@@ -550,10 +560,12 @@ register_model_group(
         "Gemma-2-2B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-2-2b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
+            DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
         },
         "Gemma-2-9B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
+            DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
         },
         "Gemma-2-27B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-2-27b-it",
@@ -573,6 +585,7 @@ register_model_group(
         "GLM-4-9B-Chat": {
             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
+            DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
         },
         "GLM-4-9B-1M-Chat": {
             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
@@ -583,6 +596,33 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Index-1.9B-Chat": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
+        },
+        "Index-1.9B-Character-Chat": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
+        },
+        "Index-1.9B-Base": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
+        },
+        "Index-1.9B-Base-Pure": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
+        },
+        "Index-1.9B-Chat-32K": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
+        },
+    },
+    template="index",
+)
+
+
 register_model_group(
     models={
         "InternLM-7B": {
@@ -624,16 +664,10 @@ register_model_group(
             DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
         },
-    },
-    template="intern2",
-)
-
-
-register_model_group(
-    models={
         "InternLM2.5-1.8B": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
         },
         "InternLM2.5-7B": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
@@ -642,22 +676,27 @@ register_model_group(
         "InternLM2.5-20B": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
         },
         "InternLM2.5-1.8B-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
         },
         "InternLM2.5-7B-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
         },
         "InternLM2.5-7B-1M-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
         },
         "InternLM2.5-20B-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
         },
     },
     template="intern2",
@@ -686,19 +725,19 @@ register_model_group(

 register_model_group(
     models={
-        "LLaMA-7B": {
+        "Llama-7B": {
             DownloadSource.DEFAULT: "huggyllama/llama-7b",
             DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
         },
-        "LLaMA-13B": {
+        "Llama-13B": {
             DownloadSource.DEFAULT: "huggyllama/llama-13b",
             DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
         },
-        "LLaMA-30B": {
+        "Llama-30B": {
             DownloadSource.DEFAULT: "huggyllama/llama-30b",
             DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
         },
-        "LLaMA-65B": {
+        "Llama-65B": {
             DownloadSource.DEFAULT: "huggyllama/llama-65b",
             DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
         },
@@ -708,27 +747,27 @@ register_model_group(

 register_model_group(
     models={
-        "LLaMA2-7B": {
+        "Llama-2-7B": {
             DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
             DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
         },
-        "LLaMA2-13B": {
+        "Llama-2-13B": {
             DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
             DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
         },
-        "LLaMA2-70B": {
+        "Llama-2-70B": {
             DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
             DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
         },
-        "LLaMA2-7B-Chat": {
+        "Llama-2-7B-Chat": {
             DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
             DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
         },
-        "LLaMA2-13B-Chat": {
+        "Llama-2-13B-Chat": {
             DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
             DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
         },
-        "LLaMA2-70B-Chat": {
+        "Llama-2-70B-Chat": {
             DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
             DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
         },
@@ -739,60 +778,78 @@ register_model_group(

 register_model_group(
     models={
-        "LLaMA3-8B": {
+        "Llama-3-8B": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
         },
-        "LLaMA3-70B": {
+        "Llama-3-70B": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
         },
-        "LLaMA3-8B-Instruct": {
+        "Llama-3-8B-Instruct": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
         },
-        "LLaMA3-70B-Instruct": {
+        "Llama-3-70B-Instruct": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
         },
-        "LLaMA3-8B-Chinese-Chat": {
+        "Llama-3-8B-Chinese-Chat": {
             DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
             DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
+            DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
         },
-        "LLaMA3-70B-Chinese-Chat": {
+        "Llama-3-70B-Chinese-Chat": {
             DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
         },
-    },
-    template="llama3",
-)
-
-
-register_model_group(
-    models={
-        "LLaMA3.1-8B": {
+        "Llama-3.1-8B": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
         },
-        "LLaMA3.1-70B": {
+        "Llama-3.1-70B": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
         },
-        "LLaMA3.1-405B": {
+        "Llama-3.1-405B": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
         },
-        "LLaMA3.1-8B-Instruct": {
+        "Llama-3.1-8B-Instruct": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
         },
-        "LLaMA3.1-70B-Instruct": {
+        "Llama-3.1-70B-Instruct": {
             DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
|
||||||
},
|
},
|
||||||
"LLaMA3.1-405B-Instruct": {
|
"Llama-3.1-405B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
|
||||||
},
|
},
|
||||||
|
"Llama-3.1-8B-Chinese-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-8B-Chinese-Chat",
|
||||||
|
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-8B-Chinese-Chat",
|
||||||
|
},
|
||||||
|
"Llama-3.1-70B-Chinese-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-70B-Chinese-Chat",
|
||||||
|
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-70B-Chinese-Chat",
|
||||||
|
},
|
||||||
|
"Llama-3.2-1B": {
|
||||||
|
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B",
|
||||||
|
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B",
|
||||||
|
},
|
||||||
|
"Llama-3.2-3B": {
|
||||||
|
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B",
|
||||||
|
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B",
|
||||||
|
},
|
||||||
|
"Llama-3.2-1B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B-Instruct",
|
||||||
|
},
|
||||||
|
"Llama-3.2-3B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
template="llama3",
|
template="llama3",
|
||||||
)
|
)
|
||||||
@ -800,11 +857,13 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaVA1.5-7B-Chat": {
|
"LLaVA-1.5-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/llava-1.5-7b-hf",
|
||||||
},
|
},
|
||||||
"LLaVA1.5-13B-Chat": {
|
"LLaVA-1.5-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/llava-1.5-13b-hf",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
template="llava",
|
template="llava",
|
||||||
@ -812,6 +871,117 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-7b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-7b-hf",
|
||||||
|
},
|
||||||
|
"LLaVA-NeXT-13B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-13b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-13b-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-Mistral-7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-mistral-7b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/llava-v1.6-mistral-7b-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next_mistral",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-Llama3-8B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llama3-llava-next-8b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/llama3-llava-next-8b-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next_llama3",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-34B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-34b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "LLM-Research/llava-v1.6-34b-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next_yi",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-72B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-next-72b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-72b-hf",
|
||||||
|
},
|
||||||
|
"LLaVA-NeXT-110B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/llava-next-110b-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-110b-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next_qwen",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-Video-7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-hf",
|
||||||
|
},
|
||||||
|
"LLaVA-NeXT-Video-7B-DPO-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-DPO-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-DPO-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next_video",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-Video-7B-32k-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-32K-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-32K-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next_video_mistral",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"LLaVA-NeXT-Video-34B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-34B-hf",
|
||||||
|
},
|
||||||
|
"LLaVA-NeXT-Video-34B-DPO-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-DPO-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="llava_next_video_yi",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"MiniCPM-2B-SFT-Chat": {
|
"MiniCPM-2B-SFT-Chat": {
|
||||||
@ -832,6 +1002,7 @@ register_model_group(
|
|||||||
"MiniCPM3-4B-Chat": {
|
"MiniCPM3-4B-Chat": {
|
||||||
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
|
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
|
||||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
|
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
template="cpm3",
|
template="cpm3",
|
||||||
@ -1005,27 +1176,27 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Phi3-4B-4k-Instruct": {
|
"Phi-3-4B-4k-Instruct": {
|
||||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
|
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
|
||||||
},
|
},
|
||||||
"Phi3-4B-128k-Instruct": {
|
"Phi-3-4B-128k-Instruct": {
|
||||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
|
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
|
||||||
},
|
},
|
||||||
"Phi3-7B-8k-Instruct": {
|
"Phi-3-7B-8k-Instruct": {
|
||||||
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
|
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
|
||||||
},
|
},
|
||||||
"Phi3-7B-128k-Instruct": {
|
"Phi-3-7B-128k-Instruct": {
|
||||||
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
|
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
|
||||||
},
|
},
|
||||||
"Phi3-14B-8k-Instruct": {
|
"Phi-3-14B-8k-Instruct": {
|
||||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
|
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
|
||||||
},
|
},
|
||||||
"Phi3-14B-128k-Instruct": {
|
"Phi-3-14B-128k-Instruct": {
|
||||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
||||||
},
|
},
|
||||||
@ -1034,6 +1205,18 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Pixtral-12B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
template="pixtral",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Qwen-1.8B": {
|
"Qwen-1.8B": {
|
||||||
@ -1068,35 +1251,35 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
|
||||||
},
|
},
|
||||||
"Qwen-1.8B-int8-Chat": {
|
"Qwen-1.8B-Chat-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-1.8B-int4-Chat": {
|
"Qwen-1.8B-Chat-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
|
||||||
},
|
},
|
||||||
"Qwen-7B-int8-Chat": {
|
"Qwen-7B-Chat-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-7B-int4-Chat": {
|
"Qwen-7B-Chat-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
|
||||||
},
|
},
|
||||||
"Qwen-14B-int8-Chat": {
|
"Qwen-14B-Chat-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-14B-int4-Chat": {
|
"Qwen-14B-Chat-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
|
||||||
},
|
},
|
||||||
"Qwen-72B-int8-Chat": {
|
"Qwen-72B-Chat-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
|
||||||
},
|
},
|
||||||
"Qwen-72B-int4-Chat": {
|
"Qwen-72B-Chat-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
||||||
},
|
},
|
||||||
@ -1179,75 +1362,75 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||||
},
|
},
|
||||||
"Qwen1.5-0.5B-int8-Chat": {
|
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen1.5-0.5B-int4-Chat": {
|
"Qwen1.5-0.5B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-1.8B-int8-Chat": {
|
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen1.5-1.8B-int4-Chat": {
|
"Qwen1.5-1.8B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-4B-int8-Chat": {
|
"Qwen1.5-4B-Chat-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen1.5-4B-int4-Chat": {
|
"Qwen1.5-4B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-7B-int8-Chat": {
|
"Qwen1.5-7B-Chat-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen1.5-7B-int4-Chat": {
|
"Qwen1.5-7B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-14B-int8-Chat": {
|
"Qwen1.5-14B-Chat-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen1.5-14B-int4-Chat": {
|
"Qwen1.5-14B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-32B-int4-Chat": {
|
"Qwen1.5-32B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-72B-int8-Chat": {
|
"Qwen1.5-72B-Chat-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen1.5-72B-int4-Chat": {
|
"Qwen1.5-72B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-110B-int4-Chat": {
|
"Qwen1.5-110B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen1.5-MoE-A2.7B-int4-Chat": {
|
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||||
},
|
},
|
||||||
"Qwen1.5-Code-7B": {
|
"CodeQwen1.5-7B": {
|
||||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
||||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
|
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
|
||||||
},
|
},
|
||||||
"Qwen1.5-Code-7B-Chat": {
|
"CodeQwen1.5-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
|
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
|
||||||
},
|
},
|
||||||
"Qwen1.5-Code-7B-int4-Chat": {
|
"CodeQwen1.5-7B-Chat-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||||
},
|
},
|
||||||
@ -1281,14 +1464,17 @@ register_model_group(
|
|||||||
"Qwen2-0.5B-Instruct": {
|
"Qwen2-0.5B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
|
||||||
},
|
},
|
||||||
"Qwen2-1.5B-Instruct": {
|
"Qwen2-1.5B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
|
||||||
},
|
},
|
||||||
"Qwen2-7B-Instruct": {
|
"Qwen2-7B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
|
||||||
},
|
},
|
||||||
"Qwen2-72B-Instruct": {
|
"Qwen2-72B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
|
||||||
@ -1568,51 +1754,53 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Qwen2VL-2B-Instruct": {
|
"Qwen2-VL-2B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
|
||||||
},
|
},
|
||||||
"Qwen2VL-7B-Instruct": {
|
"Qwen2-VL-7B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
|
||||||
},
|
},
|
||||||
"Qwen2VL-72B-Instruct": {
|
"Qwen2-VL-72B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
|
||||||
},
|
},
|
||||||
"Qwen2VL-2B-Instruct-GPTQ-Int8": {
|
"Qwen2-VL-2B-Instruct-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen2VL-2B-Instruct-GPTQ-Int4": {
|
"Qwen2-VL-2B-Instruct-GPTQ-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
|
||||||
},
|
},
|
||||||
"Qwen2VL-2B-Instruct-AWQ": {
|
"Qwen2-VL-2B-Instruct-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen2VL-7B-Instruct-GPTQ-Int8": {
|
"Qwen2-VL-7B-Instruct-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen2VL-7B-Instruct-GPTQ-Int4": {
|
"Qwen2-VL-7B-Instruct-GPTQ-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
|
||||||
},
|
},
|
||||||
"Qwen2VL-7B-Instruct-AWQ": {
|
"Qwen2-VL-7B-Instruct-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
|
||||||
},
|
},
|
||||||
"Qwen2VL-72B-Instruct-GPTQ-Int8": {
|
"Qwen2-VL-72B-Instruct-GPTQ-Int8": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
|
||||||
},
|
},
|
||||||
"Qwen2VL-72B-Instruct-GPTQ-Int4": {
|
"Qwen2-VL-72B-Instruct-GPTQ-Int4": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
|
||||||
},
|
},
|
||||||
"Qwen2VL-72B-Instruct-AWQ": {
|
"Qwen2-VL-72B-Instruct-AWQ": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
|
||||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
|
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
|
||||||
},
|
},
|
||||||
@ -1673,10 +1861,12 @@ register_model_group(
|
|||||||
"TeleChat-7B-Chat": {
|
"TeleChat-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
||||||
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
||||||
|
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
|
||||||
},
|
},
|
||||||
"TeleChat-12B-Chat": {
|
"TeleChat-12B-Chat": {
|
||||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
|
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
|
||||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
|
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
|
||||||
|
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
|
||||||
},
|
},
|
||||||
"TeleChat-12B-v2-Chat": {
|
"TeleChat-12B-v2-Chat": {
|
||||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
||||||
@ -1689,11 +1879,11 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Vicuna1.5-7B-Chat": {
|
"Vicuna-v1.5-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
||||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
|
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
|
||||||
},
|
},
|
||||||
"Vicuna1.5-13B-Chat": {
|
"Vicuna-v1.5-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
||||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
|
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
|
||||||
},
|
},
|
||||||
@ -1702,6 +1892,17 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Video-LLaVA-7B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "LanguageBind/Video-LLaVA-7B-hf",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="video_llava",
|
||||||
|
vision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"XuanYuan-6B": {
|
"XuanYuan-6B": {
|
||||||
@ -1712,7 +1913,7 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
|
||||||
},
|
},
|
||||||
"XuanYuan-2-70B": {
|
"XuanYuan2-70B": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
|
||||||
},
|
},
|
||||||
@ -1724,31 +1925,31 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||||
},
|
},
|
||||||
"XuanYuan-2-70B-Chat": {
|
"XuanYuan2-70B-Chat": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||||
},
|
},
|
||||||
"XuanYuan-6B-int8-Chat": {
|
"XuanYuan-6B-Chat-8bit": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||||
},
|
},
|
||||||
"XuanYuan-6B-int4-Chat": {
|
"XuanYuan-6B-Chat-4bit": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||||
},
|
},
|
||||||
"XuanYuan-70B-int8-Chat": {
|
"XuanYuan-70B-Chat-8bit": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||||
},
|
},
|
||||||
"XuanYuan-70B-int4-Chat": {
|
"XuanYuan-70B-Chat-4bit": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||||
},
|
},
|
||||||
"XuanYuan-2-70B-int8-Chat": {
|
"XuanYuan2-70B-Chat-8bit": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||||
},
|
},
|
||||||
"XuanYuan-2-70B-int4-Chat": {
|
"XuanYuan2-70B-Chat-4bit": {
|
||||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||||
},
|
},
|
||||||
@ -1853,19 +2054,19 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
|
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
|
||||||
},
|
},
|
||||||
"Yi-6B-int8-Chat": {
|
"Yi-6B-Chat-8bits": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
||||||
},
|
},
|
||||||
"Yi-6B-int4-Chat": {
|
"Yi-6B-Chat-4bits": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
|
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
|
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
|
||||||
},
|
},
|
||||||
"Yi-34B-int8-Chat": {
|
"Yi-34B-Chat-8bits": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
||||||
},
|
},
|
||||||
"Yi-34B-int4-Chat": {
|
"Yi-34B-Chat-4bits": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
||||||
},
|
},
|
||||||
@ -1884,6 +2085,7 @@ register_model_group(
|
|||||||
"Yi-1.5-6B-Chat": {
|
"Yi-1.5-6B-Chat": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
|
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
|
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
|
||||||
},
|
},
|
||||||
"Yi-1.5-9B-Chat": {
|
"Yi-1.5-9B-Chat": {
|
||||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
|
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
|
||||||
@ -1916,10 +2118,10 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"YiVL-6B-Chat": {
|
"Yi-VL-6B-Chat": {
|
||||||
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
|
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
|
||||||
},
|
},
|
||||||
"YiVL-34B-Chat": {
|
"Yi-VL-34B-Chat": {
|
||||||
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
|
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -72,4 +72,4 @@ def print_env() -> None:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
|
||||||
|
@ -20,6 +20,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .constants import RUNNING_LOG
|
from .constants import RUNNING_LOG
|
||||||
@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
|
|||||||
|
|
||||||
def __init__(self, output_dir: str) -> None:
|
def __init__(self, output_dir: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
formatter = logging.Formatter(
|
self._formatter = logging.Formatter(
|
||||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
)
|
)
|
||||||
self.setLevel(logging.INFO)
|
self.setLevel(logging.INFO)
|
||||||
self.setFormatter(formatter)
|
|
||||||
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
||||||
if os.path.exists(self.running_log):
|
if os.path.exists(self.running_log):
|
||||||
@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
|
|||||||
if record.name == "httpx":
|
if record.name == "httpx":
|
||||||
return
|
return
|
||||||
|
|
||||||
log_entry = self.format(record)
|
log_entry = self._formatter.format(record)
|
||||||
self.thread_pool.submit(self._write_log, log_entry)
|
self.thread_pool.submit(self._write_log, log_entry)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
|
|||||||
return super().close()
|
return super().close()
|
||||||
|
|
||||||
|
|
||||||
|
class _Logger(logging.Logger):
|
||||||
|
r"""
|
||||||
|
A logger that supports info_rank0 and warning_once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def info_rank0(self, *args, **kwargs) -> None:
|
||||||
|
self.info(*args, **kwargs)
|
||||||
|
|
||||||
|
def warning_rank0(self, *args, **kwargs) -> None:
|
||||||
|
self.warning(*args, **kwargs)
|
||||||
|
|
||||||
|
def warning_once(self, *args, **kwargs) -> None:
|
||||||
|
self.warning(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _get_default_logging_level() -> "logging._Level":
|
def _get_default_logging_level() -> "logging._Level":
|
||||||
r"""
|
r"""
|
||||||
Returns the default logging level.
|
Returns the default logging level.
|
||||||
@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
|
|||||||
if env_level_str.upper() in logging._nameToLevel:
|
if env_level_str.upper() in logging._nameToLevel:
|
||||||
return logging._nameToLevel[env_level_str.upper()]
|
return logging._nameToLevel[env_level_str.upper()]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown logging level: {}.".format(env_level_str))
|
raise ValueError(f"Unknown logging level: {env_level_str}.")
|
||||||
|
|
||||||
return _default_log_level
|
return _default_log_level
|
||||||
|
|
||||||
@ -84,7 +99,7 @@ def _get_library_name() -> str:
|
|||||||
return __name__.split(".")[0]
|
return __name__.split(".")[0]
|
||||||
|
|
||||||
|
|
||||||
def _get_library_root_logger() -> "logging.Logger":
|
def _get_library_root_logger() -> "_Logger":
|
||||||
return logging.getLogger(_get_library_name())
|
return logging.getLogger(_get_library_name())
|
||||||
|
|
||||||
|
|
||||||
@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
|
|||||||
global _default_handler
|
global _default_handler
|
||||||
|
|
||||||
with _thread_lock:
|
with _thread_lock:
|
||||||
if _default_handler:
|
if _default_handler: # already configured
|
||||||
return
|
return
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
|
||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
)
|
)
|
||||||
_default_handler = logging.StreamHandler(sys.stdout)
|
_default_handler = logging.StreamHandler(sys.stdout)
|
||||||
_default_handler.setFormatter(formatter)
|
_default_handler.setFormatter(formatter)
|
||||||
@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
|
|||||||
library_root_logger.propagate = False
|
library_root_logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name: Optional[str] = None) -> "logging.Logger":
|
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||||
r"""
|
r"""
|
||||||
Returns a logger with the specified name. It is not supposed to be accessed externally.
|
Returns a logger with the specified name. It is not supposed to be accessed externally.
|
||||||
"""
|
"""
|
||||||
@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
|
|||||||
|
|
||||||
_configure_library_root_logger()
|
_configure_library_root_logger()
|
||||||
return logging.getLogger(name)
|
return logging.getLogger(name)
|
||||||
|
|
||||||
|
|
||||||
|
def add_handler(handler: "logging.Handler") -> None:
|
||||||
|
r"""
|
||||||
|
Adds a handler to the root logger.
|
||||||
|
"""
|
||||||
|
_configure_library_root_logger()
|
||||||
|
_get_library_root_logger().addHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_handler(handler: logging.Handler) -> None:
|
||||||
|
r"""
|
||||||
|
Removes a handler from the root logger.
|
||||||
|
"""
|
||||||
|
_configure_library_root_logger()
|
||||||
|
_get_library_root_logger().removeHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||||
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
|
self.info(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||||
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
|
self.warning(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(None)
|
||||||
|
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
|
||||||
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
|
self.warning(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
logging.Logger.info_rank0 = info_rank0
|
||||||
|
logging.Logger.warning_rank0 = warning_rank0
|
||||||
|
logging.Logger.warning_once = warning_once
|
||||||
|
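The three helpers above are attached to `logging.Logger` at import time, so any logger returned by `get_logger` gains rank-aware methods. A small usage sketch, assuming the package import path `llamafactory.extras.logging` and that the launcher sets `LOCAL_RANK`:

```python
# Hypothetical usage sketch; the module path and launcher behaviour are assumptions.
from llamafactory.extras import logging

logger = logging.get_logger(__name__)
logger.info_rank0("printed only on the process where LOCAL_RANK is 0 or unset")
logger.warning_rank0("rank-0-only warning")
logger.warning_once("deduplicated by lru_cache: repeated calls with the same text log once")
```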
@ -32,7 +32,7 @@ from transformers.utils import (
|
|||||||
)
|
)
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from .logging import get_logger
|
from . import logging
|
||||||
|
|
||||||
|
|
||||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||||
@ -48,7 +48,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import ModelArguments
|
from ..hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AverageMeter:
|
class AverageMeter:
|
||||||
@ -76,12 +76,12 @@ def check_dependencies() -> None:
|
|||||||
r"""
|
r"""
|
||||||
Checks the version of the required packages.
|
Checks the version of the required packages.
|
||||||
"""
|
"""
|
||||||
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||||
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||||
else:
|
else:
|
||||||
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
|
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||||
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
|
require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
|
||||||
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
|
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
||||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||||
|
|
||||||
@ -231,18 +231,35 @@ def torch_gc() -> None:
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||||
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
|
||||||
return model_args.model_name_or_path
|
return model_args.model_name_or_path
|
||||||
|
|
||||||
try:
|
if use_modelscope():
|
||||||
from modelscope import snapshot_download
|
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||||
|
from modelscope import snapshot_download # type: ignore
|
||||||
|
|
||||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||||
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
|
return snapshot_download(
|
||||||
except ImportError:
|
model_args.model_name_or_path,
|
||||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
revision=revision,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_openmind():
|
||||||
|
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||||
|
from openmind.utils.hub import snapshot_download # type: ignore
|
||||||
|
|
||||||
|
return snapshot_download(
|
||||||
|
model_args.model_name_or_path,
|
||||||
|
revision=model_args.model_revision,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def use_modelscope() -> bool:
|
def use_modelscope() -> bool:
|
||||||
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
||||||
|
|
||||||
|
|
||||||
|
def use_openmind() -> bool:
|
||||||
|
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
|
||||||
|
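Both hub switches read plain environment variables, so `try_download_model_from_other_hub` can be steered without code changes. A minimal sketch, assuming the variable names shown in the diff and an illustrative model identifier:

```python
# Minimal sketch; USE_MODELSCOPE_HUB / USE_OPENMIND_HUB come from the diff above,
# the model identifier is only an example.
import os

os.environ["USE_MODELSCOPE_HUB"] = "1"   # or: os.environ["USE_OPENMIND_HUB"] = "1"
# model_name_or_path should then name a repository on the selected hub,
# e.g. "qwen/Qwen2-7B-Instruct" on ModelScope.
```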
@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
|
|||||||
return _get_package_version("transformers") >= version.parse("4.43.0")
|
return _get_package_version("transformers") >= version.parse("4.43.0")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def is_transformers_version_equal_to_4_46():
|
||||||
|
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
|
||||||
|
|
||||||
|
|
||||||
def is_uvicorn_available():
|
def is_uvicorn_available():
|
||||||
return _is_package_available("uvicorn")
|
return _is_package_available("uvicorn")
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ from typing import Any, Dict, List
|
|||||||
|
|
||||||
from transformers.trainer import TRAINER_STATE_NAME
|
from transformers.trainer import TRAINER_STATE_NAME
|
||||||
|
|
||||||
from .logging import get_logger
|
from . import logging
|
||||||
from .packages import is_matplotlib_available
|
from .packages import is_matplotlib_available
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ if is_matplotlib_available():
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def smooth(scalars: List[float]) -> List[float]:
|
def smooth(scalars: List[float]) -> List[float]:
|
||||||
@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
|||||||
Plots loss curves and saves the image.
|
Plots loss curves and saves the image.
|
||||||
"""
|
"""
|
||||||
plt.switch_backend("agg")
|
plt.switch_backend("agg")
|
||||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
|||||||
metrics.append(data["log_history"][i][key])
|
metrics.append(data["log_history"][i][key])
|
||||||
|
|
||||||
if len(metrics) == 0:
|
if len(metrics) == 0:
|
||||||
logger.warning(f"No metric {key} to plot.")
|
logger.warning_rank0(f"No metric {key} to plot.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
plt.figure()
|
plt.figure()
|
||||||
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
||||||
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
||||||
plt.title("training {} of {}".format(key, save_dictionary))
|
plt.title(f"training {key} of {save_dictionary}")
|
||||||
plt.xlabel("step")
|
plt.xlabel("step")
|
||||||
plt.ylabel(key)
|
plt.ylabel(key)
|
||||||
plt.legend()
|
plt.legend()
|
||||||
|
@ -41,6 +41,10 @@ class DataArguments:
|
|||||||
default="data",
|
default="data",
|
||||||
metadata={"help": "Path to the folder containing the datasets."},
|
metadata={"help": "Path to the folder containing the datasets."},
|
||||||
)
|
)
|
||||||
|
image_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
|
||||||
|
)
|
||||||
cutoff_len: int = field(
|
cutoff_len: int = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
||||||
@ -111,7 +115,13 @@ class DataArguments:
|
|||||||
)
|
)
|
||||||
tokenized_path: Optional[str] = field(
|
tokenized_path: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Path to save or load the tokenized datasets."},
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Path to save or load the tokenized datasets. "
|
||||||
|
"If tokenized_path not exists, it will save the tokenized datasets. "
|
||||||
|
"If tokenized_path exists, it will load the tokenized datasets."
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -123,6 +133,9 @@ class DataArguments:
|
|||||||
self.dataset = split_arg(self.dataset)
|
self.dataset = split_arg(self.dataset)
|
||||||
self.eval_dataset = split_arg(self.eval_dataset)
|
self.eval_dataset = split_arg(self.eval_dataset)
|
||||||
|
|
||||||
|
if self.image_dir is None:
|
||||||
|
self.image_dir = self.dataset_dir
|
||||||
|
|
||||||
if self.dataset is None and self.val_size > 1e-6:
|
if self.dataset is None and self.val_size > 1e-6:
|
||||||
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Any, Dict, Literal, Optional, Union
|
from typing import Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -267,6 +267,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||||
)
|
)
|
||||||
|
om_hub_token: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Auth token to log in with Modelers Hub."},
|
||||||
|
)
|
||||||
print_param_status: bool = field(
|
print_param_status: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
||||||
@ -308,20 +312,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
|
|||||||
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
||||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
raise ValueError("Quantization dataset is necessary for exporting.")
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return asdict(self)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
|
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
|
||||||
arg_dict = old_arg.to_dict()
|
init_args, lazy_args = {}, {}
|
||||||
arg_dict.update(**kwargs)
|
for attr in fields(source):
|
||||||
for attr in fields(cls):
|
if attr.init:
|
||||||
if not attr.init:
|
init_args[attr.name] = getattr(source, attr.name)
|
||||||
arg_dict.pop(attr.name)
|
else:
|
||||||
|
lazy_args[attr.name] = getattr(source, attr.name)
|
||||||
|
|
||||||
new_arg = cls(**arg_dict)
|
init_args.update(kwargs)
|
||||||
new_arg.compute_dtype = old_arg.compute_dtype
|
result = cls(**init_args)
|
||||||
new_arg.device_map = old_arg.device_map
|
for name, value in lazy_args.items():
|
||||||
new_arg.model_max_length = old_arg.model_max_length
|
setattr(result, name, value)
|
||||||
new_arg.block_diag_attn = old_arg.block_diag_attn
|
|
||||||
return new_arg
|
return result
|
||||||
|
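The rewritten `copyfrom` drops the `asdict` round trip: it splits fields by their `init` flag, passes only init-able values (plus overrides) to the constructor, and copies the derived fields afterwards. A standalone sketch of that pattern, using an illustrative dataclass rather than the project's `ModelArguments`:

```python
# Illustrative sketch of the copyfrom pattern; ExampleArgs is not a project class.
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class ExampleArgs:
    model_name_or_path: str = "meta-llama/Meta-Llama-3-8B"
    compute_dtype: Optional[str] = field(default=None, init=False)  # derived after parsing

    @classmethod
    def copyfrom(cls, source: "ExampleArgs", **kwargs) -> "ExampleArgs":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            if attr.init:
                init_args[attr.name] = getattr(source, attr.name)
            else:
                lazy_args[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)          # overrides only touch init-able fields
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)  # carry over derived, non-init state

        return result


old_args = ExampleArgs()
old_args.compute_dtype = "bfloat16"
new_args = ExampleArgs.copyfrom(old_args, model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct")
assert new_args.compute_dtype == "bfloat16" and new_args.model_name_or_path.endswith("Instruct")
```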
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
 import sys
 from typing import Any, Dict, Optional, Tuple
@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
 from .model_args import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 check_dependencies()
@@ -57,7 +56,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     if args is not None:
         return parser.parse_dict(args)

-    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
         return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non

     if unknown_args:
         print(parser.format_help())
-        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
-        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

     return (*parsed_args,)


-def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-    transformers.utils.logging.set_verbosity(log_level)
+def _set_transformers_logging() -> None:
+    transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()

@@ -104,7 +103,7 @@ def _verify_model_args(
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

     if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
+        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False


@@ -123,7 +122,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0")
+        require_version("vllm>=0.4.3,<=0.6.3", "To fix: pip install vllm>=0.4.3,<=0.6.3")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

     if data_args.neat_packing and not data_args.packing:
-        logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
+        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
         data_args.packing = True

     _verify_model_args(model_args, data_args, finetuning_args)
@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
+        logger.warning_rank0(
+            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
+        )

     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
-        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")

     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
-        logger.warning("We recommend enable mixed precision training.")
+        logger.warning_rank0("We recommend enable mixed precision training.")

     if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+        logger.warning_rank0(
+            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
+        )

     if (not training_args.do_train) and model_args.quantization_bit is not None:
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
-        logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

     # Post-process training arguments
     if (
@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
         if training_args.resume_from_checkpoint is not None:
-            logger.warning("Cannot resume from checkpoint in current stage.")
+            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
             training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True
@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
-            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
+            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
+            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")

     if (
         finetuning_args.stage in ["rm", "ppo"]
         and finetuning_args.finetuning_type == "lora"
         and training_args.resume_from_checkpoint is not None
     ):
-        logger.warning(
+        logger.warning_rank0(
             "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                 training_args.resume_from_checkpoint
             )
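With the change above, `_parse_args` accepts both `.yaml` and `.yml` config files. A rough usage sketch of driving `HfArgumentParser` the same way, with an illustrative dataclass and file name rather than the project's argument classes:

```python
import sys
from dataclasses import dataclass

from transformers import HfArgumentParser


@dataclass
class TrainArgs:
    model_name_or_path: str = "meta-llama/Meta-Llama-3-8B"  # placeholder
    learning_rate: float = 1.0e-4


parser = HfArgumentParser(TrainArgs)
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
    (train_args,) = parser.parse_yaml_file(sys.argv[1])  # e.g. `python train.py config.yml`
else:
    (train_args,) = parser.parse_args_into_dataclasses()

print(train_args)
```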
@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _setup_full_tuning(
@@ -45,7 +45,7 @@ def _setup_full_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Full")
+    logger.info_rank0("Fine-tuning method: Full")
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Freeze")
+    logger.info_rank0("Fine-tuning method: Freeze")
     if hasattr(model.config, "text_config"):  # composite models
         config = getattr(model.config, "text_config")
     else:
@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
         else:
             param.requires_grad_(False)

-    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+    logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))


 def _setup_lora_tuning(
@@ -145,7 +145,7 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None

@@ -182,7 +182,7 @@ def _setup_lora_tuning(
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
-            logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
+            logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")

         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:
@@ -190,7 +190,7 @@ def _setup_lora_tuning(
             else:
                 model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+        logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -219,7 +219,7 @@ def _setup_lora_tuning(
                     module_names.add(name.split(".")[-1])

             finetuning_args.additional_target = module_names
-            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+            logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

         peft_kwargs = {
             "r": finetuning_args.lora_rank,
@@ -236,11 +236,11 @@ def _setup_lora_tuning(
         else:
             if finetuning_args.pissa_init:
                 if finetuning_args.pissa_iter == -1:
-                    logger.info("Using PiSSA initialization.")
+                    logger.info_rank0("Using PiSSA initialization.")
                     peft_kwargs["init_lora_weights"] = "pissa"
                 else:
-                    logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
-                    peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+                    logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
+                    peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
@@ -284,11 +284,11 @@ def init_adapter(
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+        logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
     elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
-        logger.info("Upcasting trainable params to float32.")
+        logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

     if finetuning_args.finetuning_type == "full":
@@ -300,6 +300,6 @@ def init_adapter(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
     else:
-        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
+        raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")

     return model
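This file, like most of the hunks below, swaps `logger.info`/`logger.warning` for `info_rank0`/`warning_rank0` from the project's `extras.logging` module, so that messages are printed once per job rather than once per process in distributed runs. A minimal illustration of how such a rank-zero filter can be built on the standard library (an assumption for illustration, not the project's actual implementation):

```python
import logging
import os


class Rank0Logger(logging.Logger):
    """Logger whose *_rank0 methods emit only on the main process (LOCAL_RANK 0 or unset)."""

    def _is_rank0(self) -> bool:
        return int(os.environ.get("LOCAL_RANK", "0")) == 0

    def info_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_rank0():
            self.info(msg, *args, **kwargs)

    def warning_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_rank0():
            self.warning(msg, *args, **kwargs)


logging.setLoggerClass(Rank0Logger)
logger = logging.getLogger("example")
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
logger.info_rank0("printed once, on the main process only")
```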
@@ -18,15 +18,15 @@ import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

-from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
+from ..extras import logging
+from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
 from .adapter import init_adapter
+from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
 from .model_utils.valuehead import load_valuehead_params
-from .model_utils.visual import get_image_seqlen
-from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
+from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model


 if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class TokenizerModule(TypedDict):
@@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     Note: including inplace operation of model_args.
     """
     skip_check_imports()
-    model_args.model_name_or_path = try_download_model_from_ms(model_args)
+    model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
     return {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
@@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:

 def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     r"""
-    Loads pretrained tokenizer.
+    Loads pretrained tokenizer and optionally loads processor.

     Note: including inplace operation of model_args.
     """
@@ -82,33 +82,30 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             padding_side="right",
             **init_kwargs,
         )
+    except Exception as e:
+        raise OSError("Failed to load tokenizer.") from e

     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
         if num_added_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")

     patch_tokenizer(tokenizer)
-
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        setattr(processor, "tokenizer", tokenizer)
-        setattr(processor, "image_seqlen", get_image_seqlen(config))
-        setattr(processor, "image_resolution", model_args.image_resolution)
-        setattr(processor, "video_resolution", model_args.video_resolution)
-        setattr(processor, "video_fps", model_args.video_fps)
-        setattr(processor, "video_maxlen", model_args.video_maxlen)
-    except Exception:
+        patch_processor(processor, config, tokenizer, model_args)
+    except Exception as e:
+        logger.debug(f"Processor was not found: {e}.")
         processor = None

     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
-    if "Processor" not in processor.__class__.__name__:
+    if processor is not None and "Processor" not in processor.__class__.__name__:
         processor = None

     return {"tokenizer": tokenizer, "processor": processor}
@@ -135,6 +132,7 @@ def load_model(
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))

     model = None
     lazy_load = False
@@ -157,7 +155,7 @@ def load_model(
                 load_class = AutoModelForCausalLM

             if model_args.train_from_scratch:
-                model = load_class.from_config(config)
+                model = load_class.from_config(config, trust_remote_code=True)
             else:
                 model = load_class.from_pretrained(**init_kwargs)

@@ -182,7 +180,7 @@ def load_model(
         vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
-            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
+            logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

     if not is_trainable:
         model.requires_grad_(False)
@@ -200,9 +198,9 @@ def load_model(
             trainable_params, all_param, 100 * trainable_params / all_param
         )
     else:
-        param_stats = "all params: {:,}".format(all_param)
+        param_stats = f"all params: {all_param:,}"

-    logger.info(param_stats)
+    logger.info_rank0(param_stats)

     if model_args.print_param_status:
         for name, param in model.named_parameters():
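`load_tokenizer` above now treats the multimodal processor as optional: it attempts `AutoProcessor.from_pretrained` and falls back to `None` instead of failing on text-only models. A condensed sketch of that pattern with a placeholder model id:

```python
from typing import Optional

from transformers import AutoProcessor, AutoTokenizer, ProcessorMixin

model_id = "some/text-only-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)

processor: Optional[ProcessorMixin]
try:
    processor = AutoProcessor.from_pretrained(model_id)
except Exception:
    processor = None  # text-only models ship no multimodal processor

# AutoProcessor may return a tokenizer-like object for text-only models,
# so keep the result only if it is a real processor class.
if processor is not None and "Processor" not in processor.__class__.__name__:
    processor = None
```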
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 from transformers.utils.versions import require_version

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def configure_attn_implementation(
@@ -37,13 +37,16 @@ def configure_attn_implementation(
         if is_flash_attn_2_available():
             require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
             require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
-            logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-            model_args.flash_attn = "fa2"
+            if model_args.flash_attn != "fa2":
+                logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                model_args.flash_attn = "fa2"
         else:
-            logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
+            logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
             model_args.flash_attn = "disabled"
     elif model_args.flash_attn == "sdpa":
-        logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
+        logger.warning_rank0(
+            "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
+        )

     if model_args.flash_attn == "auto":
         return
@@ -53,18 +56,18 @@ def configure_attn_implementation(

     elif model_args.flash_attn == "sdpa":
         if not is_torch_sdpa_available():
-            logger.warning("torch>=2.1.1 is required for SDPA attention.")
+            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return

         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == "fa2":
         if not is_flash_attn_2_available():
-            logger.warning("FlashAttention-2 is not installed.")
+            logger.warning_rank0("FlashAttention-2 is not installed.")
             return

         requested_attn_implementation = "flash_attention_2"
     else:
-        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+        raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")

     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
         setattr(config, "attn_implementation", requested_attn_implementation)
@@ -79,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
     attn_implementation = getattr(config, "_attn_implementation", None)

     if attn_implementation == "flash_attention_2":
-        logger.info("Using FlashAttention-2 for faster training and inference.")
+        logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
     elif attn_implementation == "sdpa":
-        logger.info("Using torch SDPA for faster training and inference.")
+        logger.info_rank0("Using torch SDPA for faster training and inference.")
     else:
-        logger.info("Using vanilla attention implementation.")
+        logger.info_rank0("Using vanilla attention implementation.")
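`configure_attn_implementation` above chooses between eager, SDPA and FlashAttention-2 and records the choice on the model config. A rough sketch of how such a selected backend is typically passed to `from_pretrained` (model id and dtype are placeholders, and the FA2 branch assumes flash-attn plus a CUDA device):

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

if is_flash_attn_2_available():
    attn_implementation = "flash_attention_2"
elif is_torch_sdpa_available():
    attn_implementation = "sdpa"
else:
    attn_implementation = "eager"

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder causal LM
    attn_implementation=attn_implementation,
    torch_dtype=torch.bfloat16,
)
print(model.config._attn_implementation)
```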
@@ -19,14 +19,14 @@
 # limitations under the License.

 import inspect
-from functools import partial, wraps
+from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

 import torch

+from ...extras import logging
 from ...extras.constants import LAYERNORM_NAMES
-from ...extras.logging import get_logger


 if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def get_unsloth_gradient_checkpointing_func() -> Callable:
@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     Only applies gradient checkpointing to trainable layers.
     """

-    @wraps(gradient_checkpointing_func)
+    @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__

@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable

         return gradient_checkpointing_func(func, *args, **kwargs)

-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
-        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
-
     return custom_gradient_checkpointing_func


@@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
     from torch.utils.checkpoint import checkpoint

     if not self.supports_gradient_checkpointing:
-        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {"use_reentrant": True}
@@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
-        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

@@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
     (3) add the upcasting of the lm_head in fp32
     """
     if model_args.upcast_layernorm:
-        logger.info("Upcasting layernorm weights in float32.")
+        logger.info_rank0("Upcasting layernorm weights in float32.")
         for name, param in model.named_parameters():
             if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                 param.data = param.data.to(torch.float32)

     if not model_args.disable_gradient_checkpointing:
         if not getattr(model, "supports_gradient_checkpointing", False):
-            logger.warning("Current model does not support gradient checkpointing.")
+            logger.warning_rank0("Current model does not support gradient checkpointing.")
         else:
             # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
             # According to: https://github.com/huggingface/transformers/issues/28339
@@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
             model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
-            logger.info("Gradient checkpointing enabled.")
+            logger.info_rank0("Gradient checkpointing enabled.")

     if model_args.upcast_lmhead_output:
         output_layer = model.get_output_embeddings()
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            logger.info("Upcasting lm_head outputs in float32.")
+            logger.info_rank0("Upcasting lm_head outputs in float32.")
             output_layer.register_forward_hook(_fp32_forward_post_hook)
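The `@wraps(..., assigned=WRAPPER_ASSIGNMENTS + ("__self__",))` change above replaces the manual `hasattr(..., "__self__")` copy: when the wrapped checkpointing callable is a bound method, the wrapper now inherits `__self__` automatically. A small standalone illustration of that mechanism with a toy class:

```python
from functools import WRAPPER_ASSIGNMENTS, wraps


class Checkpointer:
    def run(self, x: int) -> int:
        return x * 2


bound = Checkpointer().run  # a bound method carries __self__


@wraps(bound, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def wrapped(*args, **kwargs):
    # extra behaviour could go here (e.g. only checkpoint trainable modules)
    return bound(*args, **kwargs)


print(wrapped(3))        # 6
print(wrapped.__self__)  # the Checkpointer instance, copied over by @wraps
```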
@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
         _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
         _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

-    logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
+    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
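`resize_embedding_layer` above initializes the rows added for new tokens with `_noisy_mean_initialization`, i.e. roughly the mean of the existing embeddings plus small noise. The function body is not part of this diff; the following is a hedged sketch of what such an initializer commonly looks like:

```python
import math

import torch


def noisy_mean_init(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
    # Hypothetical helper: fill the trailing `num_new_tokens` rows with the mean of the
    # pre-existing embeddings plus Gaussian noise scaled by 1/sqrt(dim).
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
    noise_weight.normal_(mean=0.0, std=1.0 / math.sqrt(embedding_dim))
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight


embeddings = torch.randn(100, 16)                          # 100 existing tokens, 16-dim
embeddings = torch.cat([embeddings, torch.zeros(4, 16)])   # 4 newly added tokens
noisy_mean_init(embeddings, num_new_tokens=4)
print(embeddings[-4:].mean(dim=1))
```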
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -23,10 +24,15 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


-def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def apply_liger_kernel(
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    is_trainable: bool,
+    require_logits: bool,
+) -> None:
     if not is_trainable or not model_args.enable_liger_kernel:
         return

@@ -48,8 +54,14 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     else:
-        logger.warning("Current model does not support liger kernel.")
+        logger.warning_rank0("Current model does not support liger kernel.")
         return

-    apply_liger_kernel()
-    logger.info("Liger kernel has been applied to the model.")
+    if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
+        logger.info_rank0("Current training stage does not support chunked cross entropy.")
+        kwargs = {"fused_linear_cross_entropy": False}
+    else:
+        kwargs = {}
+
+    apply_liger_kernel(**kwargs)
+    logger.info_rank0("Liger kernel has been applied to the model.")
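`apply_liger_kernel` above only passes `fused_linear_cross_entropy=False` when the selected liger patch function actually accepts that keyword, which it discovers through `inspect.signature`. The same guard in isolation, with dummy patch functions instead of liger-kernel:

```python
import inspect
from typing import Callable


def call_with_supported_kwargs(func: Callable, **requested) -> None:
    supported = inspect.signature(func).parameters
    kwargs = {k: v for k, v in requested.items() if k in supported}
    func(**kwargs)


def patch_v1(rope: bool = True) -> None:
    print("v1:", rope)


def patch_v2(rope: bool = True, fused_linear_cross_entropy: bool = True) -> None:
    print("v2:", rope, fused_linear_cross_entropy)


# The newer patch receives the extra flag, the older one is called without it.
call_with_supported_kwargs(patch_v1, fused_linear_cross_entropy=False)
call_with_supported_kwargs(patch_v2, fused_linear_cross_entropy=False)
```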
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple

 import torch
 import torch.nn as nn
+import transformers
 from transformers.models.llama.modeling_llama import (
     Cache,
     LlamaAttention,
@@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43


@@ -44,7 +44,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-transformers_logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)


 # Modified from:
@@ -86,7 +86,7 @@ def llama_attention_forward(

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return

-    logger = get_logger(__name__)
+    logger = logging.get_logger(__name__)

     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
-        logger.info("Using shift short attention with group_size_ratio=1/4.")
+        logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
     else:
-        logger.warning("Current model does not support shift short attention.")
+        logger.warning_rank0("Current model does not support shift short attention.")
@@ -14,14 +14,14 @@

 from typing import TYPE_CHECKING, List

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -34,7 +34,7 @@ find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "paligemma"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
     elif model_type == "qwen2_vl":
         forbidden_modules.add("merger")
@@ -53,7 +53,7 @@ find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])

-    logger.info("Found linear modules: {}".format(",".join(module_names)))
+    logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
     return list(module_names)


@@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n

     if num_layers % num_layer_trainable != 0:
         raise ValueError(
-            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+            f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
         )

     stride = num_layers // num_layer_trainable
     trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
-    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+    trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
     module_names = []
     for name, _ in model.named_modules():
         if any(target_module in name for target_module in target_modules) and any(
@@ -80,7 +80,7 @@ find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)

-    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
     return module_names
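The stride arithmetic in `find_expanded_modules` above selects every `stride`-th transformer block for expanded LoRA modules. A quick worked example of the index computation:

```python
num_layers, num_layer_trainable = 32, 8
if num_layers % num_layer_trainable != 0:
    raise ValueError("`num_layers` should be divisible by `num_layer_trainable`.")

stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
print(trainable_layers)  # ['.3.', '.7.', '.11.', '.15.', '.19.', '.23.', '.27.', '.31.']
```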