Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-12 16:12:48 +08:00)

Commit d99e164cad: Merge branch 'main' into main
Former-commit-id: 5f14910910154ba569435e7e68acbd6c30f79e80
@@ -7,6 +7,8 @@ data
 docker
 saves
 hf_cache
+ms_cache
+om_cache
 output
 .dockerignore
 .gitattributes
.env.local (17 changed lines)

@@ -1,32 +1,35 @@
 # Note: actually we do not support .env, just for reference
 # api
-API_HOST=0.0.0.0
+API_HOST=
-API_PORT=8000
+API_PORT=
 API_KEY=
-API_MODEL_NAME=gpt-3.5-turbo
+API_MODEL_NAME=
 FASTAPI_ROOT_PATH=
+MAX_CONCURRENT=
 # general
 DISABLE_VERSION_CHECK=
 FORCE_CHECK_IMPORTS=
 LLAMAFACTORY_VERBOSITY=
 USE_MODELSCOPE_HUB=
+USE_OPENMIND_HUB=
 RECORD_VRAM=
 # torchrun
 FORCE_TORCHRUN=
 MASTER_ADDR=
 MASTER_PORT=
 NNODES=
-RANK=
+NODE_RANK=
 NPROC_PER_NODE=
 # wandb
 WANDB_DISABLED=
-WANDB_PROJECT=huggingface
+WANDB_PROJECT=
 WANDB_API_KEY=
 # gradio ui
-GRADIO_SHARE=False
+GRADIO_SHARE=
-GRADIO_SERVER_NAME=0.0.0.0
+GRADIO_SERVER_NAME=
 GRADIO_SERVER_PORT=
 GRADIO_ROOT_PATH=
+GRADIO_IPV6=
 # setup
 ENABLE_SHORT_CONSOLE=1
 # reserved (do not use)
.github/CONTRIBUTING.md (46 changed lines)

@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
 ### Style guide

 LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
+
+### Create a Pull Request
+
+1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
+
+2. Clone your fork to your local disk, and add the base repository as a remote:
+
+   ```bash
+   git clone git@github.com:[username]/LLaMA-Factory.git
+   cd LLaMA-Factory
+   git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
+   ```
+
+3. Create a new branch to hold your development changes:
+
+   ```bash
+   git checkout -b dev_your_branch
+   ```
+
+4. Set up a development environment by running the following command in a virtual environment:
+
+   ```bash
+   pip install -e ".[dev]"
+   ```
+
+   If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the -e flag.
+
+5. Check code before commit:
+
+   ```bash
+   make commit
+   make style && make quality
+   make test
+   ```
+
+6. Submit changes:
+
+   ```bash
+   git add .
+   git commit -m "commit message"
+   git fetch upstream
+   git rebase upstream/main
+   git push -u origin dev_your_branch
+   ```
+
+7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).
.github/workflows/tests.yml (3 changed lines)

@@ -22,7 +22,7 @@ jobs:
 fail-fast: false
 matrix:
 python-version:
-- "3.8"
+- "3.8" # TODO: remove py38 in next transformers release
 - "3.9"
 - "3.10"
 - "3.11"

@@ -54,7 +54,6 @@ jobs:
 - name: Install dependencies
 run: |
 python -m pip install --upgrade pip
-python -m pip install git+https://github.com/huggingface/transformers.git
 python -m pip install ".[torch,dev]"

 - name: Check quality
.gitignore (1 changed line)

@@ -162,6 +162,7 @@ cython_debug/
 # custom .gitignore
 ms_cache/
 hf_cache/
+om_cache/
 cache/
 config/
 saves/
.pre-commit-config.yaml (new file, 28 lines)

@@ -0,0 +1,28 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: check-ast
+      - id: check-added-large-files
+        args: ['--maxkb=25000']
+      - id: check-merge-conflict
+      - id: check-yaml
+      - id: debug-statements
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+        args: [--markdown-linebreak-ext=md]
+      - id: no-commit-to-branch
+        args: ['--branch', 'main']
+
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.17.0
+    hooks:
+      - id: pyupgrade
+        args: [--py38-plus]
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.9
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
Makefile (11 changed lines)

@@ -1,7 +1,14 @@
-.PHONY: quality style test
+.PHONY: build commit quality style test

 check_dirs := scripts src tests setup.py

+build:
+	pip install build && python -m build
+
+commit:
+	pre-commit install
+	pre-commit run --all-files
+
 quality:
 	ruff check $(check_dirs)
 	ruff format --check $(check_dirs)

@@ -11,4 +18,4 @@ style:
 	ruff format $(check_dirs)

 test:
-	CUDA_VISIBLE_DEVICES= pytest tests/
+	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
README.md (54 changed lines)

@@ -4,7 +4,7 @@
 [](LICENSE)
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llamafactory/)
 [](#projects-using-llama-factory)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)

@@ -26,10 +26,17 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
 Choose your path:

 - **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
-- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
 - **Local machine**: Please refer to [usage](#getting-started)
 - **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
+
+Recent activities:
+
+- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)
+
+> [!NOTE]
+> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.

 ## Table of Contents

 - [Features](#features)

@@ -72,6 +79,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

+[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
+
 [24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

 [24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.

@@ -130,7 +139,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).

-[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.

 [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.

@@ -163,7 +172,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Supported Models

 | Model | Model size | Template |
-| ----------------------------------------------------------------- | -------------------------------- | --------- |
+| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
 | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |

@@ -172,20 +181,23 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
 | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
+| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
-| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi-small |
+| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
+| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
+| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |

@@ -360,7 +372,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```

-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality

 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.

@@ -412,7 +424,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij
 ### Data Preparation

-Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
+Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope / Modelers hub or load the dataset in local disk.

 > [!NOTE]
 > Please update `data/dataset_info.json` to use your custom dataset.

@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
 docker run -dit --gpus=all \
 -v ./hf_cache:/root/.cache/huggingface \
 -v ./ms_cache:/root/.cache/modelscope \
+-v ./om_cache:/root/.cache/openmind \
 -v ./data:/app/data \
 -v ./output:/app/output \
 -p 7860:7860 \

@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
 docker run -dit \
 -v ./hf_cache:/root/.cache/huggingface \
 -v ./ms_cache:/root/.cache/modelscope \
+-v ./om_cache:/root/.cache/openmind \
 -v ./data:/app/data \
 -v ./output:/app/output \
 -v /usr/local/dcmi:/usr/local/dcmi \

@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
 docker run -dit \
 -v ./hf_cache:/root/.cache/huggingface \
 -v ./ms_cache:/root/.cache/modelscope \
+-v ./om_cache:/root/.cache/openmind \
 -v ./data:/app/data \
 -v ./output:/app/output \
 -v ./saves:/app/saves \

@@ -557,6 +572,7 @@ docker exec -it llamafactory bash
 - `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
 - `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
+- `om_cache`: Similar to Hugging Face cache but for Modelers users.
 - `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
 - `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.

@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 > [!TIP]
 > Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
+>
+> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)

 ### Download from ModelScope Hub

@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
 Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.

+### Download from Modelers Hub
+
+You can also use Modelers Hub to download models and datasets.
+
+```bash
+export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
+```
+
+Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
+
 ### Use W&B Logger

 To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.

@@ -684,11 +712,13 @@ If you have a project that should be incorporated, please contact via email or c
 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
-1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
 1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
+1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)

 </details>

@@ -696,7 +726,7 @@ If you have a project that should be incorporated, please contact via email or c
 This repository is licensed under the [Apache-2.0 License](LICENSE).

-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation
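The Modelers Hub support added above rests on a single environment variable, `USE_OPENMIND_HUB`, mirroring the existing `USE_MODELSCOPE_HUB` switch from `.env.local`. The sketch below is only an illustration of how such a switch could select a download hub; the function names and the fallback order are assumptions, not the project's actual implementation.

```python
import os


def _env_flag(name: str) -> bool:
    # Treat "1", "true", "y", "yes" (any case) as enabled; anything else as disabled.
    return os.getenv(name, "0").strip().lower() in ("1", "true", "y", "yes")


def resolve_hub() -> str:
    # Hypothetical helper: pick the hub based on the env switches documented above.
    if _env_flag("USE_OPENMIND_HUB"):
        return "openmind"      # Modelers Hub (modelers.cn)
    if _env_flag("USE_MODELSCOPE_HUB"):
        return "modelscope"    # ModelScope Hub (modelscope.cn)
    return "huggingface"       # default hub


if __name__ == "__main__":
    os.environ["USE_OPENMIND_HUB"] = "1"
    print(resolve_hub())  # -> openmind
```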
README_zh.md (49 changed lines)

@@ -4,7 +4,7 @@
 [](LICENSE)
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llamafactory/)
 [](#使用了-llama-factory-的项目)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://twitter.com/llamafactory_ai)

@@ -26,11 +26,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 选择你的打开方式:

 - **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
-- **PAI-DSW**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
+- **PAI-DSW**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
 - **本地机器**:请见[如何使用](#如何使用)
 - **入门教程**:https://zhuanlan.zhihu.com/p/695287607
 - **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
+
+近期活动:
+
+- **2024/10/18-2024/11/30**:使用 PAI+LLaMA Factory 构建个性化导游机器人。[[活动页面]](https://developer.aliyun.com/topic/llamafactory2)
+
+> [!NOTE]
+> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。

 ## 目录

 - [项目特色](#项目特色)

@@ -73,6 +80,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 ## 更新日志

+[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
+
 [24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。

 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。

@@ -164,7 +173,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 ## 模型

 | 模型名 | 模型大小 | Template |
-| ----------------------------------------------------------------- | -------------------------------- | --------- |
+| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
 | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |

@@ -173,19 +182,22 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
 | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
+| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
+| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
+| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |

@@ -360,7 +372,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```

-可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
+可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality

 > [!TIP]
 > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。

@@ -412,7 +424,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 ### 数据准备

-关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
+关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope / Modelers 上的数据集或加载本地数据集。

 > [!NOTE]
 > 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。

@@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
 docker run -dit --gpus=all \
 -v ./hf_cache:/root/.cache/huggingface \
 -v ./ms_cache:/root/.cache/modelscope \
+-v ./om_cache:/root/.cache/openmind \
 -v ./data:/app/data \
 -v ./output:/app/output \
 -p 7860:7860 \

@@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
 docker run -dit \
 -v ./hf_cache:/root/.cache/huggingface \
 -v ./ms_cache:/root/.cache/modelscope \
+-v ./om_cache:/root/.cache/openmind \
 -v ./data:/app/data \
 -v ./output:/app/output \
 -v /usr/local/dcmi:/usr/local/dcmi \

@@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
 docker run -dit \
 -v ./hf_cache:/root/.cache/huggingface \
 -v ./ms_cache:/root/.cache/modelscope \
+-v ./om_cache:/root/.cache/openmind \
 -v ./data:/app/data \
 -v ./output:/app/output \
 -v ./saves:/app/saves \

@@ -557,6 +572,7 @@ docker exec -it llamafactory bash
 - `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
 - `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
+- `om_cache`:类似 Hugging Face 缓存文件夹,为 Modelers 用户提供。
 - `data`:宿主机中存放数据集的文件夹路径。
 - `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。

@@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 > [!TIP]
 > API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
+>
+> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py)

 ### 从魔搭社区下载

@@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
 将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。

+### 从魔乐社区下载
+
+您也可以通过下述方法,使用魔乐社区下载数据集和模型。
+
+```bash
+export USE_OPENMIND_HUB=1 # Windows 使用 `set USE_OPENMIND_HUB=1`
+```
+
+将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔乐社区](https://modelers.cn/models)查看所有可用的模型,例如 `TeleAI/TeleChat-7B-pt`。
+
 ### 使用 W&B 面板

 若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。

@@ -684,11 +712,12 @@ run_name: test_run # 可选
 1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
 1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
 1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
-1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
+1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
 1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
 1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
 1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调.
+1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)

 </details>

@@ -696,7 +725,7 @@ run_name: test_run # 可选
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。

-使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## 引用
Binary file not shown (image asset): 145 KiB before, 132 KiB after.
Binary file not shown (image asset): 149 KiB before, 132 KiB after.
@@ -17,9 +17,9 @@ _CITATION = """\
 }
 """

-_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
 _LICENSE = "gpl-3.0"
-_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"


 class BelleMultiturn(datasets.GeneratorBasedBuilder):

@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
         return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

     def _generate_examples(self, filepath: str):
-        with open(filepath, "r", encoding="utf-8") as f:
+        with open(filepath, encoding="utf-8") as f:
             for key, row in enumerate(f):
                 data = json.loads(row)
                 conversations = []
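The dataset-loader changes above (and the similar ones below) are mechanical modernizations of the kind enforced by the new pyupgrade hook with `--py38-plus`: `str.format` calls on literal templates become f-strings, and the redundant `"r"` mode argument to `open` is dropped. A minimal self-contained before/after sketch (the file name `example.jsonl` is only for illustration):

```python
from pathlib import Path

endpoint = "https://huggingface.co"
sample = Path("example.jsonl")
sample.write_text('{"id": 0}\n', encoding="utf-8")  # create a throwaway file to read

# Before: explicit .format() call and a redundant "r" mode argument.
homepage_old = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(endpoint)
with open(sample, "r", encoding="utf-8") as f:
    first_old = f.readline()

# After: f-string interpolation and the default read mode.
homepage_new = f"{endpoint}/datasets/BelleGroup/multiturn_chat_0.8M"
with open(sample, encoding="utf-8") as f:
    first_new = f.readline()

assert homepage_old == homepage_new and first_old == first_new  # behavior unchanged
```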
@@ -54,7 +54,8 @@
   },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
-    "ms_hub_url": "llamafactory/alpaca_en"
+    "ms_hub_url": "llamafactory/alpaca_en",
+    "om_hub_url": "HaM/alpaca_en"
   },
   "alpaca_zh": {
     "hf_hub_url": "llamafactory/alpaca_zh",

@@ -66,7 +67,8 @@
   },
   "alpaca_gpt4_zh": {
     "hf_hub_url": "llamafactory/alpaca_gpt4_zh",
-    "ms_hub_url": "llamafactory/alpaca_gpt4_zh"
+    "ms_hub_url": "llamafactory/alpaca_gpt4_zh",
+    "om_hub_url": "State_Cloud/alpaca-gpt4-data-zh"
   },
   "glaive_toolcall_en": {
     "hf_hub_url": "llamafactory/glaive_toolcall_en",
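Each entry in `data/dataset_info.json` maps a dataset name to per-hub identifiers, and this commit adds an `om_hub_url` key next to the existing `hf_hub_url` and `ms_hub_url`. As a rough illustration only (the helper `hub_dataset_id` and its fallback behavior are assumptions, not the project's loader), a lookup over such an entry might look like this:

```python
import json

# An entry shaped like the ones edited above.
dataset_info = json.loads("""
{
  "alpaca_en": {
    "hf_hub_url": "llamafactory/alpaca_en",
    "ms_hub_url": "llamafactory/alpaca_en",
    "om_hub_url": "HaM/alpaca_en"
  }
}
""")


def hub_dataset_id(name: str, hub: str) -> str:
    # Hypothetical lookup: return the per-hub identifier, falling back to the Hugging Face one.
    entry = dataset_info[name]
    key = {"huggingface": "hf_hub_url", "modelscope": "ms_hub_url", "openmind": "om_hub_url"}[hub]
    return entry.get(key, entry["hf_hub_url"])


print(hub_dataset_id("alpaca_en", "openmind"))  # -> HaM/alpaca_en
```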
@@ -8,9 +8,9 @@ import datasets
 _HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
 _DESCRIPTION = "Human preference data about helpfulness and harmlessness."
 _CITATION = ""
-_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
 _LICENSE = "mit"
-_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
+_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
 _URLS = {
     "train": [
         _URL + "harmless-base/train.jsonl.gz",

@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
     def _generate_examples(self, filepaths: List[str]):
         key = 0
         for filepath in filepaths:
-            with open(filepath, "r", encoding="utf-8") as f:
+            with open(filepath, encoding="utf-8") as f:
                 for row in f:
                     data = json.loads(row)
                     chosen = data["chosen"]
@@ -20,9 +20,9 @@ _CITATION = """\
 }
 """

-_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
+_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
 _LICENSE = "cc-by-nc-4.0"
-_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
+_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"


 class UltraChat(datasets.GeneratorBasedBuilder):

@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
     def _generate_examples(self, filepaths: List[str]):
         for filepath in filepaths:
-            with open(filepath, "r", encoding="utf-8") as f:
+            with open(filepath, encoding="utf-8") as f:
                 for row in f:
                     try:
                         data = json.loads(row)
@@ -1,6 +1,7 @@
-# Use the NVIDIA official image with PyTorch 2.3.0
+# Default use the NVIDIA official image with PyTorch 2.3.0
-# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
-FROM nvcr.io/nvidia/pytorch:24.02-py3
+ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
+FROM ${BASE_IMAGE}

 # Define environments
 ENV MAX_JOBS=4

@@ -12,6 +13,9 @@ ARG INSTALL_BNB=false
 ARG INSTALL_VLLM=false
 ARG INSTALL_DEEPSPEED=false
 ARG INSTALL_FLASHATTN=false
+ARG INSTALL_LIGER_KERNEL=false
+ARG INSTALL_HQQ=false
+ARG INSTALL_EETQ=false
 ARG PIP_INDEX=https://pypi.org/simple

 # Set the working directory

@@ -38,6 +42,15 @@ RUN EXTRA_PACKAGES="metrics"; \
 if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
 EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
 fi; \
+if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
+fi; \
+if [ "$INSTALL_HQQ" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
+fi; \
+if [ "$INSTALL_EETQ" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
+fi; \
 pip install -e ".[$EXTRA_PACKAGES]"

 # Rebuild flash attention
@@ -8,11 +8,15 @@ services:
 INSTALL_VLLM: false
 INSTALL_DEEPSPEED: false
 INSTALL_FLASHATTN: false
+INSTALL_LIGER_KERNEL: false
+INSTALL_HQQ: false
+INSTALL_EETQ: false
 PIP_INDEX: https://pypi.org/simple
 container_name: llamafactory
 volumes:
 - ../../hf_cache:/root/.cache/huggingface
 - ../../ms_cache:/root/.cache/modelscope
+- ../../om_cache:/root/.cache/openmind
 - ../../data:/app/data
 - ../../output:/app/output
 ports:
@@ -10,6 +10,7 @@ services:
 volumes:
 - ../../hf_cache:/root/.cache/huggingface
 - ../../ms_cache:/root/.cache/modelscope
+- ../../om_cache:/root/.cache/openmind
 - ../../data:/app/data
 - ../../output:/app/output
 - /usr/local/dcmi:/usr/local/dcmi
@@ -10,6 +10,8 @@ ARG INSTALL_BNB=false
 ARG INSTALL_VLLM=false
 ARG INSTALL_DEEPSPEED=false
 ARG INSTALL_FLASHATTN=false
+ARG INSTALL_LIGER_KERNEL=false
+ARG INSTALL_HQQ=false
 ARG PIP_INDEX=https://pypi.org/simple

 # Set the working directory

@@ -36,6 +38,12 @@ RUN EXTRA_PACKAGES="metrics"; \
 if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
 EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
 fi; \
+if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
+fi; \
+if [ "$INSTALL_HQQ" == "true" ]; then \
+EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
+fi; \
 pip install -e ".[$EXTRA_PACKAGES]"

 # Rebuild flash attention
@@ -8,11 +8,14 @@ services:
 INSTALL_VLLM: false
 INSTALL_DEEPSPEED: false
 INSTALL_FLASHATTN: false
+INSTALL_LIGER_KERNEL: false
+INSTALL_HQQ: false
 PIP_INDEX: https://pypi.org/simple
 container_name: llamafactory
 volumes:
 - ../../hf_cache:/root/.cache/huggingface
 - ../../ms_cache:/root/.cache/modelscope
+- ../../om_cache:/root/.cache/openmind
 - ../../data:/app/data
 - ../../output:/app/output
 - ../../saves:/app/saves
@@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
         df = pd.read_csv(filepath, header=None)
         df.columns = ["question", "A", "B", "C", "D", "answer"]
 
-        for i, instance in enumerate(df.to_dict(orient="records")):
-            yield i, instance
+        yield from enumerate(df.to_dict(orient="records"))
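The rewritten MMLU generator above is behaviorally identical to the loop it replaces: `yield from enumerate(...)` emits the same `(index, record)` pairs. A minimal standalone sketch of the equivalence (the DataFrame content here is illustrative, not from the dataset):

```python
import pandas as pd

df = pd.DataFrame({"question": ["q1", "q2"], "answer": ["A", "B"]})


def old_style(df):
    for i, instance in enumerate(df.to_dict(orient="records")):
        yield i, instance


def new_style(df):
    yield from enumerate(df.to_dict(orient="records"))


assert list(old_style(df)) == list(new_style(df))  # identical (index, record) pairs
```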
@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 #### Supervised Fine-Tuning on Multiple Nodes
 
 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```
 
 #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 #### Supervised Fine-Tuning on Multiple Nodes
 
 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```
 
 #### Distributing VRAM Evenly with DeepSpeed ZeRO-3
@@ -1,9 +1,9 @@
-transformers>=4.41.2,<=4.45.0
-datasets>=2.16.0,<=2.21.0
-accelerate>=0.30.1,<=0.34.2
+transformers>=4.41.2,<=4.46.1
+datasets>=2.16.0,<=3.0.2
+accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
-gradio>=4.0.0
+gradio>=4.0.0,<5.0.0
 pandas>=2.0.0
 scipy
 einops
@@ -19,3 +19,4 @@ fire
 packaging
 pyyaml
 numpy<2.0.0
+av
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
 #
 # This code is inspired by the Microsoft's DeepSpeed library.
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 imoneoi and the LlamaFactory team.
 #
 # This code is inspired by the imoneoi's OpenChat library.
@@ -74,7 +73,7 @@ def calculate_lr(
     elif stage == "sft":
         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
     else:
-        raise NotImplementedError("Stage does not supported: {}.".format(stage))
+        raise NotImplementedError(f"Stage does not supported: {stage}.")
 
     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     valid_tokens, total_tokens = 0, 0
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
     elif "4090" in device_name:
         return 98 * 1e12 * world_size
     else:
-        raise NotImplementedError("Device not supported: {}.".format(device_name))
+        raise NotImplementedError(f"Device not supported: {device_name}.")
 
 
 def calculate_mfu(
@@ -140,10 +139,10 @@ def calculate_mfu(
         "bf16": True,
     }
     if deepspeed_stage in [2, 3]:
-        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
+        args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
 
     run_exp(args)
-    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
         result = json.load(f)
 
     if dist.is_initialized():
@@ -157,7 +156,7 @@ def calculate_mfu(
         * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
         / compute_device_flops(world_size)
     )
-    print("MFU: {:.2f}%".format(mfu_value * 100))
+    print(f"MFU: {mfu_value * 100:.2f}%")
 
 
 if __name__ == "__main__":
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def calculate_ppl(
             tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
         )
     else:
-        raise NotImplementedError("Stage does not supported: {}.".format(stage))
+        raise NotImplementedError(f"Stage does not supported: {stage}.")
 
     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
@@ -125,8 +124,8 @@ def calculate_ppl(
     with open(save_name, "w", encoding="utf-8") as f:
         json.dump(perplexities, f, indent=2)
 
-    print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
-    print("Perplexities have been saved at {}.".format(save_name))
+    print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
+    print(f"Perplexities have been saved at {save_name}.")
 
 
 if __name__ == "__main__":
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,7 +60,7 @@ def length_cdf(
     for length, count in length_tuples:
         count_accu += count
         prob_accu += count / total_num * 100
-        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
+        print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
 
 
 if __name__ == "__main__":
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Tencent Inc. and the LlamaFactory team.
 #
 # This code is inspired by the Tencent's LLaMA-Pro library.
@@ -40,7 +39,7 @@ if TYPE_CHECKING:
 
 
 def change_name(name: str, old_index: int, new_index: int) -> str:
-    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
+    return name.replace(f".{old_index:d}.", f".{new_index:d}.")
 
 
 def block_expansion(
@@ -76,27 +75,27 @@ def block_expansion(
     state_dict = model.state_dict()
 
     if num_layers % num_expand != 0:
-        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
+        raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
 
     split = num_layers // num_expand
     layer_cnt = 0
     output_state_dict = OrderedDict()
     for i in range(num_layers):
         for key, value in state_dict.items():
-            if ".{:d}.".format(i) in key:
+            if f".{i:d}." in key:
                 output_state_dict[change_name(key, i, layer_cnt)] = value
 
-        print("Add layer {} copied from layer {}".format(layer_cnt, i))
+        print(f"Add layer {layer_cnt} copied from layer {i}")
         layer_cnt += 1
         if (i + 1) % split == 0:
             for key, value in state_dict.items():
-                if ".{:d}.".format(i) in key:
+                if f".{i:d}." in key:
                     if "down_proj" in key or "o_proj" in key:
                         output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                     else:
                         output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
 
-            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
+            print(f"Add layer {layer_cnt} expanded from layer {i}")
             layer_cnt += 1
 
     for key, value in state_dict.items():
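For intuition about the interleaving in the hunk above: one zero-initialized expansion layer is appended after every `split = num_layers // num_expand` copied layers, so a model with 8 layers expanded by 2 ends up with 10. A minimal sketch of the index bookkeeping only, with illustrative values and no model weights involved:

```python
num_layers, num_expand = 8, 2  # illustrative values
split = num_layers // num_expand  # one expansion layer after every 4 originals

layer_cnt = 0
plan = []
for i in range(num_layers):
    plan.append((layer_cnt, f"copy of layer {i}"))
    layer_cnt += 1
    if (i + 1) % split == 0:
        plan.append((layer_cnt, f"zero-initialized expansion of layer {i}"))
        layer_cnt += 1

print(plan)  # 10 entries; new layers land at indices 4 and 9
```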
@@ -113,17 +112,17 @@ def block_expansion(
         torch.save(shard, os.path.join(output_dir, shard_file))
 
     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")
 
     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
+    print(f"model_name_or_path: {output_dir}")
     print("finetuning_type: freeze")
-    print("freeze_trainable_layers: {}".format(num_expand))
+    print(f"freeze_trainable_layers: {num_expand}")
     print("use_llama_pro: true")
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
         torch.save(shard, os.path.join(output_dir, shard_file))
 
     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+        print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")
 
 
 def save_config(input_dir: str, output_dir: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
         llama2_config_dict: Dict[str, Any] = json.load(f)
 
     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
 
     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
         json.dump(llama2_config_dict, f, indent=2)
-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
 
 
 def llamafy_baichuan2(
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
         elif "lm_head" in key:
             llama2_state_dict[key] = value
         else:
-            raise KeyError("Unable to process key {}".format(key))
+            raise KeyError(f"Unable to process key {key}")
 
     weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
     shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
         torch.save(shard, os.path.join(output_dir, shard_file))
 
     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")
 
     return str(torch_dtype).replace("torch.", "")
 
 
 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
         qwen_config_dict: Dict[str, Any] = json.load(f)
 
     llama2_config_dict: Dict[str, Any] = OrderedDict()
@@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
 
     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
         json.dump(llama2_config_dict, f, indent=2)
-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
 
 
 def llamafy_qwen(
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -70,19 +69,19 @@ def quantize_loftq(
     setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
     peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
-    print("Adapter weights saved in {}".format(loftq_dir))
+    print(f"Adapter weights saved in {loftq_dir}")
 
     # Save base model
     base_model: "PreTrainedModel" = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
-    print("Model weights saved in {}".format(output_dir))
+    print(f"Model weights saved in {output_dir}")
 
     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
-    print("adapter_name_or_path: {}".format(loftq_dir))
+    print(f"model_name_or_path: {output_dir}")
+    print(f"adapter_name_or_path: {loftq_dir}")
     print("finetuning_type: lora")
-    print("quantization_bit: {}".format(loftq_bits))
+    print(f"quantization_bit: {loftq_bits}")
 
 
 if __name__ == "__main__":
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -54,7 +53,7 @@ def quantize_pissa(
         lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
         lora_dropout=lora_dropout,
         target_modules=lora_target,
-        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+        init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
     )
 
     # Init PiSSA model
@@ -65,17 +64,17 @@ def quantize_pissa(
     setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
     peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
-    print("Adapter weights saved in {}".format(pissa_dir))
+    print(f"Adapter weights saved in {pissa_dir}")
 
     # Save base model
     base_model: "PreTrainedModel" = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
-    print("Model weights saved in {}".format(output_dir))
+    print(f"Model weights saved in {output_dir}")
 
     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
-    print("adapter_name_or_path: {}".format(pissa_dir))
+    print(f"model_name_or_path: {output_dir}")
+    print(f"adapter_name_or_path: {pissa_dir}")
     print("finetuning_type: lora")
     print("pissa_init: false")
     print("pissa_convert: true")
 65  scripts/test_image.py  Normal file
@@ -0,0 +1,65 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from openai import OpenAI
+from transformers.utils.versions import require_version
+
+
+require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
+
+
+def main():
+    client = OpenAI(
+        api_key="{}".format(os.environ.get("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+    )
+    messages = []
+    messages.append(
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Output the color and number of each box."},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
+                },
+            ],
+        }
+    )
+    result = client.chat.completions.create(messages=messages, model="test")
+    messages.append(result.choices[0].message)
+    print("Round 1:", result.choices[0].message.content)
+    # The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
+    messages.append(
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "What kind of flower is this?"},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
+                },
+            ],
+        }
+    )
+    result = client.chat.completions.create(messages=messages, model="test")
+    messages.append(result.choices[0].message)
+    print("Round 2:", result.choices[0].message.content)
+    # The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
+
+
+if __name__ == "__main__":
+    main()
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 11  setup.py
@@ -20,7 +20,7 @@ from setuptools import find_packages, setup
 
 
 def get_version() -> str:
-    with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
         file_content = f.read()
     pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
     (version,) = re.findall(pattern, file_content)
@@ -28,7 +28,7 @@ def get_version() -> str:
 
 
 def get_requires() -> List[str]:
-    with open("requirements.txt", "r", encoding="utf-8") as f:
+    with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
     lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
     return lines
@@ -54,13 +54,14 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.6.0"],
+    "vllm": ["vllm>=0.4.3,<=0.6.3"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
    "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "dev": ["ruff", "pytest"],
+    "openmind": ["openmind"],
+    "dev": ["pre-commit", "ruff", "pytest"],
 }
 
 
@@ -71,7 +72,7 @@ def main():
         author="hiyouga",
         author_email="hiyouga" "@" "buaa.edu.cn",
         description="Easy-to-use LLM fine-tuning framework",
-        long_description=open("README.md", "r", encoding="utf-8").read(),
+        long_description=open("README.md", encoding="utf-8").read(),
         long_description_content_type="text/markdown",
         keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
         license="Apache 2.0 License",
@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
@@ -20,17 +20,17 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.45.0
-    datasets>=2.16.0,<=2.21.0
-    accelerate>=0.30.1,<=0.34.2
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.0.2
+    accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.45.0
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.45.0
+    transformers>=4.41.2,<=4.46.1
 
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
@@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
 """
 
 from .extras.env import VERSION
@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU mem
 
 
 def create_app(chat_model: "ChatModel") -> "FastAPI":
-    root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
     app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY", None)
+    api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)
 
     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])
 
     @app.post(
@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 def run_api() -> None:
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
@@ -21,7 +21,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
 from ..data import Role as DataRole
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 ROLE_MAPPING = {
     Role.USER: DataRole.USER.value,
     Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -69,8 +69,8 @@ ROLE_MAPPING = {
 
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
-    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
 
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -84,7 +84,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
     input_messages = []
-    image = None
+    images = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -111,7 +111,7 @@ def _process_request(
                 else:  # web uri
                     image_stream = requests.get(image_url, stream=True).raw
 
-                image = Image.open(image_stream).convert("RGB")
+                images.append(Image.open(image_stream).convert("RGB"))
             else:
                 input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
@@ -124,7 +124,7 @@ def _process_request(
     else:
        tools = None
 
-    return input_messages, system, tools, image
+    return input_messages, system, tools, images or None
 
 
 def _create_stream_chat_completion_chunk(
@@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -169,7 +169,7 @@ async def create_chat_completion_response(
             tool_calls = []
             for tool in result:
                 function = Function(name=tool[0], arguments=tool[1])
-                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+                tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
 
             response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
@@ -193,8 +193,8 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 
@@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
-    score_id = "scoreval-{}".format(uuid.uuid4().hex)
+    score_id = f"scoreval-{uuid.uuid4().hex}"
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
@@ -66,8 +66,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
@@ -81,8 +81,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
@@ -53,7 +53,7 @@ class ChatModel:
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
-            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
 
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
@@ -64,15 +64,15 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
         )
         return task.result()
 
@@ -81,28 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
 
     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -115,14 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
             yield new_token
 
     def get_scores(
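Because the public chat API above now takes `images` and `videos` as sequences, callers pass a list even for a single attachment. A minimal usage sketch under the new signature; the model path, template name and image files are placeholders, not part of this change:

```python
from PIL import Image

from llamafactory.chat import ChatModel

# placeholder model path and template; any local multimodal checkpoint would do
chat_model = ChatModel(dict(model_name_or_path="path/to/qwen2_vl_model", template="qwen2_vl"))

messages = [{"role": "user", "content": "<image><image>Compare the two pictures."}]
images = [Image.open("first.png").convert("RGB"), Image.open("second.png").convert("RGB")]

for response in chat_model.chat(messages, images=images):
    print(response.response_text)
```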
@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
 from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class HuggingfaceEngine(BaseEngine):
@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning("There is no current event loop, creating a new one.")
+            logger.warning_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
-        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
 
     @staticmethod
     def _process_args(
@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
-        if image is not None:
-            mm_input_dict.update({"images": [image], "imglens": [1]})
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 
-        if video is not None:
-            mm_input_dict.update({"videos": [video], "vidlens": [1]})
-            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER not in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 
         messages = template.mm_plugin.process_messages(
             messages, mm_input_dict["images"], mm_input_dict["videos"], processor
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
         if stop is not None:
-            logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
 
         generating_args = generating_args.copy()
         generating_args.update(
@@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine):
 
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
+
             gen_kwargs[key] = value.to(model.device)
 
         return gen_kwargs, prompt_length
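The new branch above covers processors that return a list of per-image tensors rather than one pre-batched tensor; stacking is only valid when every tensor in the list shares a shape, as the inline comment notes. A standalone sketch of that conversion, with illustrative shapes:

```python
import torch

# e.g. pixel values delivered as one tensor per image
pixel_values = [torch.randn(3, 224, 224) for _ in range(2)]

batched = torch.stack(pixel_values)  # works only because both tensors have the same shape
print(batched.shape)  # torch.Size([2, 3, 224, 224])
```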
@@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
video,
|
videos,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
if not self.can_generate:
|
if not self.can_generate:
|
||||||
@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
tools,
|
tools,
|
||||||
image,
|
images,
|
||||||
video,
|
videos,
|
||||||
input_kwargs,
|
input_kwargs,
|
||||||
)
|
)
|
||||||
async with self.semaphore:
|
async with self.semaphore:
|
||||||
|
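The HuggingfaceEngine hunks above replace the single `image` / `video` arguments with `images` / `videos` sequences throughout the chat and streaming entry points. A minimal usage sketch, assuming the top-level `ChatModel` wrapper (not part of the hunks shown) forwards the same keywords; the model path, template name, and image file below are placeholders:

```python
from llamafactory.chat import ChatModel

# Placeholder arguments; any multimodal checkpoint supported by the engine works here.
chat_model = ChatModel(dict(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct", template="qwen2_vl"))

messages = [{"role": "user", "content": "<image>What is shown in this picture?"}]
# `images` is now a sequence, so several images can be attached to a single request.
responses = chat_model.chat(messages, images=["example_image.jpg"])
print(responses[0].response_text)
```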
@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.misc import get_device_count
|
from ..extras.misc import get_device_count
|
||||||
from ..extras.packages import is_pillow_available, is_vllm_available
|
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||||
from ..model import load_config, load_tokenizer
|
from ..model import load_config, load_tokenizer
|
||||||
@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VllmEngine(BaseEngine):
|
class VllmEngine(BaseEngine):
|
||||||
@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
|
|||||||
if getattr(config, "is_yi_vl_derived_model", None):
|
if getattr(config, "is_yi_vl_derived_model", None):
|
||||||
import vllm.model_executor.models.llava
|
import vllm.model_executor.models.llava
|
||||||
|
|
||||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||||
|
|
||||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||||
@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncIterator["RequestOutput"]:
|
) -> AsyncIterator["RequestOutput"]:
|
||||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
if image is not None:
|
if images is not None:
|
||||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||||
|
|
||||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
system = system or self.generating_args["default_system"]
|
system = system or self.generating_args["default_system"]
|
||||||
@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
|
|||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if image is not None: # add image features
|
if images is not None: # add image features
|
||||||
|
image_data = []
|
||||||
|
for image in images:
|
||||||
if not isinstance(image, (str, ImageObject)):
|
if not isinstance(image, (str, ImageObject)):
|
||||||
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
|
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||||
|
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image).convert("RGB")
|
image = Image.open(image).convert("RGB")
|
||||||
|
|
||||||
multi_modal_data = {"image": image}
|
image_data.append(image)
|
||||||
|
|
||||||
|
multi_modal_data = {"image": image_data}
|
||||||
else:
|
else:
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
|
|
||||||
@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> List["Response"]:
|
) -> List["Response"]:
|
||||||
final_output = None
|
final_output = None
|
||||||
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||||
async for request_output in generator:
|
async for request_output in generator:
|
||||||
final_output = request_output
|
final_output = request_output
|
||||||
|
|
||||||
@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
|
|||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
system: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
tools: Optional[str] = None,
|
tools: Optional[str] = None,
|
||||||
image: Optional["ImageInput"] = None,
|
images: Optional[Sequence["ImageInput"]] = None,
|
||||||
video: Optional["VideoInput"] = None,
|
videos: Optional[Sequence["VideoInput"]] = None,
|
||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||||
async for result in generator:
|
async for result in generator:
|
||||||
delta_text = result.outputs[0].text[len(generated_text) :]
|
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||||
generated_text = result.outputs[0].text
|
generated_text = result.outputs[0].text
|
||||||
|
@ -22,8 +22,8 @@ from . import launcher
|
|||||||
from .api.app import run_api
|
from .api.app import run_api
|
||||||
from .chat.chat_model import run_chat
|
from .chat.chat_model import run_chat
|
||||||
from .eval.evaluator import run_eval
|
from .eval.evaluator import run_eval
|
||||||
|
from .extras import logging
|
||||||
from .extras.env import VERSION, print_env
|
from .extras.env import VERSION, print_env
|
||||||
from .extras.logging import get_logger
|
|
||||||
from .extras.misc import get_device_count
|
from .extras.misc import get_device_count
|
||||||
from .train.tuner import export_model, run_exp
|
from .train.tuner import export_model, run_exp
|
||||||
from .webui.interface import run_web_demo, run_web_ui
|
from .webui.interface import run_web_demo, run_web_ui
|
||||||
@ -47,7 +47,7 @@ USAGE = (
|
|||||||
WELCOME = (
|
WELCOME = (
|
||||||
"-" * 58
|
"-" * 58
|
||||||
+ "\n"
|
+ "\n"
|
||||||
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
|
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||||
+ " " * (21 - len(VERSION))
|
+ " " * (21 - len(VERSION))
|
||||||
+ "|\n|"
|
+ "|\n|"
|
||||||
+ " " * 56
|
+ " " * 56
|
||||||
@ -56,7 +56,7 @@ WELCOME = (
|
|||||||
+ "-" * 58
|
+ "-" * 58
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@unique
|
@unique
|
||||||
@ -86,19 +86,19 @@ def main():
|
|||||||
elif command == Command.EXPORT:
|
elif command == Command.EXPORT:
|
||||||
export_model()
|
export_model()
|
||||||
elif command == Command.TRAIN:
|
elif command == Command.TRAIN:
|
||||||
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||||
if force_torchrun or get_device_count() > 1:
|
if force_torchrun or get_device_count() > 1:
|
||||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
|
||||||
process = subprocess.run(
|
process = subprocess.run(
|
||||||
(
|
(
|
||||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||||
).format(
|
).format(
|
||||||
nnodes=os.environ.get("NNODES", "1"),
|
nnodes=os.getenv("NNODES", "1"),
|
||||||
node_rank=os.environ.get("RANK", "0"),
|
node_rank=os.getenv("NODE_RANK", "0"),
|
||||||
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
file_name=launcher.__file__,
|
file_name=launcher.__file__,
|
||||||
@ -118,4 +118,4 @@ def main():
|
|||||||
elif command == Command.HELP:
|
elif command == Command.HELP:
|
||||||
print(USAGE)
|
print(USAGE)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown command: {}.".format(command))
|
raise NotImplementedError(f"Unknown command: {command}.")
|
||||||
|
@ -16,7 +16,7 @@ import os
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
from .data_utils import Role
|
from .data_utils import Role
|
||||||
|
|
||||||
|
|
||||||
@ -29,45 +29,51 @@ if TYPE_CHECKING:
|
|||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _convert_images(
|
def _convert_images(
|
||||||
images: Sequence["ImageInput"],
|
images: Union["ImageInput", Sequence["ImageInput"]],
|
||||||
dataset_attr: "DatasetAttr",
|
dataset_attr: "DatasetAttr",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Optional[List["ImageInput"]]:
|
) -> Optional[List["ImageInput"]]:
|
||||||
r"""
|
r"""
|
||||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||||
"""
|
"""
|
||||||
if len(images) == 0:
|
if not isinstance(images, list):
|
||||||
|
images = [images]
|
||||||
|
elif len(images) == 0:
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
images = images[:]
|
images = images[:]
|
||||||
|
|
||||||
if dataset_attr.load_from in ["script", "file"]:
|
if dataset_attr.load_from in ["script", "file"]:
|
||||||
for i in range(len(images)):
|
for i in range(len(images)):
|
||||||
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
|
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
|
||||||
images[i] = os.path.join(data_args.dataset_dir, images[i])
|
images[i] = os.path.join(data_args.image_dir, images[i])
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
def _convert_videos(
|
def _convert_videos(
|
||||||
videos: Sequence["VideoInput"],
|
videos: Union["VideoInput", Sequence["VideoInput"]],
|
||||||
dataset_attr: "DatasetAttr",
|
dataset_attr: "DatasetAttr",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Optional[List["VideoInput"]]:
|
) -> Optional[List["VideoInput"]]:
|
||||||
r"""
|
r"""
|
||||||
Optionally concatenates video path to dataset dir when loading from local disk.
|
Optionally concatenates video path to dataset dir when loading from local disk.
|
||||||
"""
|
"""
|
||||||
if len(videos) == 0:
|
if not isinstance(videos, list):
|
||||||
|
videos = [videos]
|
||||||
|
elif len(videos) == 0:
|
||||||
return None
|
return None
|
||||||
|
else:
|
||||||
videos = videos[:]
|
videos = videos[:]
|
||||||
|
|
||||||
if dataset_attr.load_from in ["script", "file"]:
|
if dataset_attr.load_from in ["script", "file"]:
|
||||||
for i in range(len(videos)):
|
for i in range(len(videos)):
|
||||||
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
|
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
|
||||||
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
|
videos[i] = os.path.join(data_args.image_dir, videos[i])
|
||||||
|
|
||||||
return videos
|
return videos
|
||||||
|
|
||||||
@ -161,7 +167,7 @@ def convert_sharegpt(
|
|||||||
broken_data = False
|
broken_data = False
|
||||||
for turn_idx, message in enumerate(messages):
|
for turn_idx, message in enumerate(messages):
|
||||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||||
logger.warning("Invalid role tag in {}.".format(messages))
|
logger.warning_rank0(f"Invalid role tag in {messages}.")
|
||||||
broken_data = True
|
broken_data = True
|
||||||
|
|
||||||
aligned_messages.append(
|
aligned_messages.append(
|
||||||
@ -171,7 +177,7 @@ def convert_sharegpt(
|
|||||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||||
):
|
):
|
||||||
logger.warning("Invalid message count in {}.".format(messages))
|
logger.warning_rank0(f"Invalid message count in {messages}.")
|
||||||
broken_data = True
|
broken_data = True
|
||||||
|
|
||||||
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||||
@ -192,7 +198,7 @@ def convert_sharegpt(
|
|||||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||||
):
|
):
|
||||||
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
|
||||||
broken_data = True
|
broken_data = True
|
||||||
|
|
||||||
prompt = aligned_messages
|
prompt = aligned_messages
|
||||||
@ -205,7 +211,7 @@ def convert_sharegpt(
|
|||||||
response = aligned_messages[-1:]
|
response = aligned_messages[-1:]
|
||||||
|
|
||||||
if broken_data:
|
if broken_data:
|
||||||
logger.warning("Skipping this abnormal example.")
|
logger.warning_rank0("Skipping this abnormal example.")
|
||||||
prompt, response = [], []
|
prompt, response = [], []
|
||||||
|
|
||||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
|
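The aligner hunks above let `_convert_images` / `_convert_videos` accept either a single item or a sequence, and join local paths against `data_args.image_dir` instead of `dataset_dir`. A standalone sketch of the same normalization, with `image_dir` standing in for `data_args.image_dir` and the "file"/"script" load path assumed:

```python
import os
from typing import List, Optional, Sequence, Union


def convert_images(images: Union[str, Sequence[str]], image_dir: str) -> Optional[List[str]]:
    if not isinstance(images, list):  # a bare path becomes a one-element list
        images = [images]
    elif len(images) == 0:  # an empty list means "no images"
        return None
    else:
        images = images[:]  # copy before modifying in place

    for i in range(len(images)):
        # only rewrite entries that resolve to existing files under image_dir
        if isinstance(images[i], str) and os.path.isfile(os.path.join(image_dir, images[i])):
            images[i] = os.path.join(image_dir, images[i])

    return images


print(convert_images("0001.jpg", image_dir="data/mllm_demo_data"))
```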
@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||||
features.update(mm_inputs)
|
features.update(mm_inputs)
|
||||||
|
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
|
||||||
|
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
|||||||
for key in ("chosen", "rejected"):
|
for key in ("chosen", "rejected"):
|
||||||
for feature in features:
|
for feature in features:
|
||||||
target_feature = {
|
target_feature = {
|
||||||
"input_ids": feature["{}_input_ids".format(key)],
|
"input_ids": feature[f"{key}_input_ids"],
|
||||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
"attention_mask": feature[f"{key}_attention_mask"],
|
||||||
"labels": feature["{}_labels".format(key)],
|
"labels": feature[f"{key}_labels"],
|
||||||
"images": feature["images"],
|
"images": feature["images"],
|
||||||
"videos": feature["videos"],
|
"videos": feature["videos"],
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
|
|||||||
|
|
||||||
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||||
@ -56,12 +56,12 @@ def merge_dataset(
|
|||||||
return all_datasets[0]
|
return all_datasets[0]
|
||||||
elif data_args.mix_strategy == "concat":
|
elif data_args.mix_strategy == "concat":
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
|
||||||
|
|
||||||
return concatenate_datasets(all_datasets)
|
return concatenate_datasets(all_datasets)
|
||||||
elif data_args.mix_strategy.startswith("interleave"):
|
elif data_args.mix_strategy.startswith("interleave"):
|
||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||||
|
|
||||||
return interleave_datasets(
|
return interleave_datasets(
|
||||||
datasets=all_datasets,
|
datasets=all_datasets,
|
||||||
@ -70,7 +70,7 @@ def merge_dataset(
|
|||||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
|
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def split_dataset(
|
||||||
|
@ -83,14 +83,14 @@ class StringFormatter(Formatter):
|
|||||||
if isinstance(slot, str):
|
if isinstance(slot, str):
|
||||||
for name, value in kwargs.items():
|
for name, value in kwargs.items():
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise RuntimeError("Expected a string, got {}".format(value))
|
raise RuntimeError(f"Expected a string, got {value}")
|
||||||
|
|
||||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
elif isinstance(slot, (dict, set)):
|
elif isinstance(slot, (dict, set)):
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
|
|||||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
functions = []
|
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
for name, arguments in functions:
|
for name, arguments in functions:
|
||||||
@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
|
|||||||
elif isinstance(slot, (dict, set)):
|
elif isinstance(slot, (dict, set)):
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
|
|||||||
tools = json.loads(content)
|
tools = json.loads(content)
|
||||||
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return [""]
|
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
|
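The formatter hunks above turn malformed JSON into a hard `RuntimeError` instead of a silent fallback. A small sketch of the function-call content the `FunctionFormatter` expects to parse (the tool name and arguments are made up for illustration; the surrounding template wiring is assumed):

```python
import json

# A well-formed function-call message: an object (or list of objects) with
# "name" and "arguments" keys, serialized into the message content.
tool_call = {"name": "get_weather", "arguments": {"city": "Beijing", "unit": "celsius"}}
content = json.dumps(tool_call, ensure_ascii=False)

# The formatter json.loads() this content; with the change above, content that is
# not valid JSON now raises instead of producing an empty function list.
parsed = json.loads(content)
print(parsed["name"], parsed["arguments"])
```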
@ -20,8 +20,8 @@ import numpy as np
|
|||||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import FILEEXT2TYPE
|
from ..extras.constants import FILEEXT2TYPE
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.misc import has_tokenized_data
|
from ..extras.misc import has_tokenized_data
|
||||||
from .aligner import align_dataset
|
from .aligner import align_dataset
|
||||||
from .data_utils import merge_dataset, split_dataset
|
from .data_utils import merge_dataset, split_dataset
|
||||||
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
|||||||
from .template import Template
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _load_single_dataset(
|
def _load_single_dataset(
|
||||||
@ -51,9 +51,9 @@ def _load_single_dataset(
|
|||||||
r"""
|
r"""
|
||||||
Loads a single dataset and aligns it to the standard format.
|
Loads a single dataset and aligns it to the standard format.
|
||||||
"""
|
"""
|
||||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
logger.info_rank0(f"Loading dataset {dataset_attr}...")
|
||||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
|
||||||
data_path = dataset_attr.dataset_name
|
data_path = dataset_attr.dataset_name
|
||||||
data_name = dataset_attr.subset
|
data_name = dataset_attr.subset
|
||||||
data_dir = dataset_attr.folder
|
data_dir = dataset_attr.folder
|
||||||
@ -69,25 +69,24 @@ def _load_single_dataset(
|
|||||||
if os.path.isdir(local_path): # is directory
|
if os.path.isdir(local_path): # is directory
|
||||||
for file_name in os.listdir(local_path):
|
for file_name in os.listdir(local_path):
|
||||||
data_files.append(os.path.join(local_path, file_name))
|
data_files.append(os.path.join(local_path, file_name))
|
||||||
if data_path is None:
|
|
||||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
|
||||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
|
||||||
raise ValueError("File types should be identical.")
|
|
||||||
elif os.path.isfile(local_path): # is file
|
elif os.path.isfile(local_path): # is file
|
||||||
data_files.append(local_path)
|
data_files.append(local_path)
|
||||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("File {} not found.".format(local_path))
|
raise ValueError(f"File {local_path} not found.")
|
||||||
|
|
||||||
|
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||||
|
|
||||||
|
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
|
||||||
|
raise ValueError("File types should be identical.")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||||
|
|
||||||
if dataset_attr.load_from == "ms_hub":
|
if dataset_attr.load_from == "ms_hub":
|
||||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||||
from modelscope import MsDataset
|
from modelscope import MsDataset # type: ignore
|
||||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||||
dataset = MsDataset.load(
|
dataset = MsDataset.load(
|
||||||
@ -98,10 +97,27 @@ def _load_single_dataset(
|
|||||||
split=dataset_attr.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.ms_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
use_streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
if isinstance(dataset, MsDataset):
|
if isinstance(dataset, MsDataset):
|
||||||
dataset = dataset.to_hf_dataset()
|
dataset = dataset.to_hf_dataset()
|
||||||
|
|
||||||
|
elif dataset_attr.load_from == "om_hub":
|
||||||
|
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||||
|
from openmind import OmDataset # type: ignore
|
||||||
|
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
|
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
|
||||||
|
dataset = OmDataset.load_dataset(
|
||||||
|
path=data_path,
|
||||||
|
name=data_name,
|
||||||
|
data_dir=data_dir,
|
||||||
|
data_files=data_files,
|
||||||
|
split=dataset_attr.split,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
token=model_args.om_hub_token,
|
||||||
|
streaming=data_args.streaming,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path=data_path,
|
path=data_path,
|
||||||
@ -111,13 +127,10 @@ def _load_single_dataset(
|
|||||||
split=dataset_attr.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
streaming=data_args.streaming,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
|
||||||
|
|
||||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||||
target_num = dataset_attr.num_samples
|
target_num = dataset_attr.num_samples
|
||||||
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||||
@ -128,7 +141,7 @@ def _load_single_dataset(
|
|||||||
|
|
||||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||||
dataset = dataset.select(indexes)
|
dataset = dataset.select(indexes)
|
||||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
|
||||||
|
|
||||||
if data_args.max_samples is not None: # truncate dataset
|
if data_args.max_samples is not None: # truncate dataset
|
||||||
max_samples = min(data_args.max_samples, len(dataset))
|
max_samples = min(data_args.max_samples, len(dataset))
|
||||||
@ -224,9 +237,9 @@ def get_dataset(
|
|||||||
# Load tokenized dataset
|
# Load tokenized dataset
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if has_tokenized_data(data_args.tokenized_path):
|
if has_tokenized_data(data_args.tokenized_path):
|
||||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
|
||||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
|
||||||
|
|
||||||
dataset_module: Dict[str, "Dataset"] = {}
|
dataset_module: Dict[str, "Dataset"] = {}
|
||||||
if "train" in dataset_dict:
|
if "train" in dataset_dict:
|
||||||
@ -277,8 +290,8 @@ def get_dataset(
|
|||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if training_args.should_save:
|
if training_args.should_save:
|
||||||
dataset_dict.save_to_disk(data_args.tokenized_path)
|
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
|
||||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
|
||||||
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from io import BytesIO
|
|||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from transformers.image_utils import get_image_size, to_numpy_array
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
@ -110,7 +111,7 @@ class BasePlugin:
|
|||||||
image = Image.open(image["path"])
|
image = Image.open(image["path"])
|
||||||
|
|
||||||
if not isinstance(image, ImageObject):
|
if not isinstance(image, ImageObject):
|
||||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
|
||||||
|
|
||||||
results.append(self._preprocess_image(image, **kwargs))
|
results.append(self._preprocess_image(image, **kwargs))
|
||||||
|
|
||||||
@ -157,6 +158,7 @@ class BasePlugin:
|
|||||||
It holds num_patches == torch.prod(image_grid_thw)
|
It holds num_patches == torch.prod(image_grid_thw)
|
||||||
"""
|
"""
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
|
||||||
input_dict = {"images": None} # default key
|
input_dict = {"images": None} # default key
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
@ -174,10 +176,16 @@ class BasePlugin:
|
|||||||
)
|
)
|
||||||
input_dict["videos"] = videos
|
input_dict["videos"] = videos
|
||||||
|
|
||||||
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
|
mm_inputs = {}
|
||||||
return image_processor(**input_dict, return_tensors="pt")
|
if image_processor != video_processor:
|
||||||
else:
|
if input_dict.get("images") is not None:
|
||||||
return {}
|
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
|
||||||
|
if input_dict.get("videos") is not None:
|
||||||
|
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
|
||||||
|
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
|
||||||
|
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
def process_messages(
|
def process_messages(
|
||||||
self,
|
self,
|
||||||
@ -218,6 +226,14 @@ class BasePlugin:
|
|||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
r"""
|
r"""
|
||||||
Builds batched multimodal inputs for VLMs.
|
Builds batched multimodal inputs for VLMs.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
images: a list of image inputs, shape (num_images,)
|
||||||
|
videos: a list of video inputs, shape (num_videos,)
|
||||||
|
imglens: number of images in each sample, shape (batch_size,)
|
||||||
|
vidlens: number of videos in each sample, shape (batch_size,)
|
||||||
|
seqlens: number of tokens in each sample, shape (batch_size,)
|
||||||
|
processor: a processor for pre-processing images and videos
|
||||||
"""
|
"""
|
||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
return {}
|
return {}
|
||||||
@ -245,7 +261,123 @@ class LlavaPlugin(BasePlugin):
|
|||||||
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
|
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
if len(images) != num_image_tokens:
|
||||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
seqlens: Sequence[int],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
num_image_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
if "image_sizes" in mm_inputs:
|
||||||
|
image_sizes = iter(mm_inputs["image_sizes"])
|
||||||
|
if "pixel_values" in mm_inputs:
|
||||||
|
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while self.image_token in content:
|
||||||
|
image_size = next(image_sizes)
|
||||||
|
orig_height, orig_width = image_size
|
||||||
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
|
if processor.vision_feature_select_strategy == "default":
|
||||||
|
image_seqlen -= 1
|
||||||
|
num_image_tokens += 1
|
||||||
|
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
|
||||||
|
|
||||||
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
seqlens: Sequence[int],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
res = self._get_mm_inputs(images, videos, processor)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextVideoPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
num_image_tokens = 0
|
||||||
|
num_video_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
if "pixel_values" in mm_inputs:
|
||||||
|
image_sizes = iter(mm_inputs["image_sizes"])
|
||||||
|
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
|
||||||
|
while self.image_token in content:
|
||||||
|
image_size = next(image_sizes)
|
||||||
|
orig_height, orig_width = image_size
|
||||||
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
|
if processor.vision_feature_select_strategy == "default":
|
||||||
|
image_seqlen -= 1
|
||||||
|
num_image_tokens += 1
|
||||||
|
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
|
||||||
|
|
||||||
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
|
if "pixel_values_videos" in mm_inputs:
|
||||||
|
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||||
|
height, width = get_image_size(pixel_values_video[0])
|
||||||
|
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||||
|
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
||||||
|
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while self.video_token in content:
|
||||||
|
num_video_tokens += 1
|
||||||
|
content = content.replace(self.video_token, "{{video}}", 1)
|
||||||
|
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
if len(videos) != num_video_tokens:
|
||||||
|
raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -284,7 +416,7 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
message["content"] = content.replace("{{image}}", "")
|
message["content"] = content.replace("{{image}}", "")
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
if len(images) != num_image_tokens:
|
||||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -324,6 +456,68 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
patch_size = getattr(processor, "patch_size")
|
||||||
|
image_token = getattr(processor, "image_token")
|
||||||
|
image_break_token = getattr(processor, "image_break_token")
|
||||||
|
image_end_token = getattr(processor, "image_end_token")
|
||||||
|
|
||||||
|
num_image_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
image_input_sizes = mm_inputs.get("image_sizes", None)
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
if image_input_sizes is None:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
image_size = image_input_sizes[0][num_image_tokens]
|
||||||
|
height, width = image_size
|
||||||
|
num_height_tokens = height // patch_size
|
||||||
|
num_width_tokens = width // patch_size
|
||||||
|
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||||
|
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
||||||
|
replace_tokens[-1] = image_end_token
|
||||||
|
replace_str = "".join(replace_tokens)
|
||||||
|
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
|
message["content"] = content
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
seqlens: Sequence[int],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
if mm_inputs.get("pixel_values"):
|
||||||
|
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
|
||||||
|
|
||||||
|
mm_inputs.pop("image_sizes", None)
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
class Qwen2vlPlugin(BasePlugin):
|
class Qwen2vlPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||||
@ -369,7 +563,7 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if num_image_tokens >= len(image_grid_thw):
|
if num_image_tokens >= len(image_grid_thw):
|
||||||
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER,
|
||||||
@ -382,7 +576,7 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
if num_video_tokens >= len(video_grid_thw):
|
if num_video_tokens >= len(video_grid_thw):
|
||||||
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
|
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
VIDEO_PLACEHOLDER,
|
VIDEO_PLACEHOLDER,
|
||||||
@ -396,10 +590,73 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
if len(images) != num_image_tokens:
|
||||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
if len(videos) != num_video_tokens:
|
||||||
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
|
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
seqlens: Sequence[int],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLlavaPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
num_image_tokens = 0
|
||||||
|
num_video_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
num_frames = 0
|
||||||
|
exist_images = "pixel_values_images" in mm_inputs
|
||||||
|
exist_videos = "pixel_values_videos" in mm_inputs
|
||||||
|
if exist_videos or exist_images:
|
||||||
|
if exist_images:
|
||||||
|
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
|
||||||
|
num_frames = 1
|
||||||
|
if exist_videos:
|
||||||
|
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||||
|
height, width = get_image_size(pixel_values_video[0])
|
||||||
|
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||||
|
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
|
||||||
|
video_seqlen = image_seqlen * num_frames
|
||||||
|
if processor.vision_feature_select_strategy == "default":
|
||||||
|
image_seqlen -= 1
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while self.image_token in content:
|
||||||
|
num_image_tokens += 1
|
||||||
|
content = content.replace(self.image_token, "{{image}}", 1)
|
||||||
|
while self.video_token in content:
|
||||||
|
num_video_tokens += 1
|
||||||
|
content = content.replace(self.video_token, "{{video}}", 1)
|
||||||
|
|
||||||
|
content = content.replace("{{image}}", self.image_token * image_seqlen)
|
||||||
|
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
|
||||||
|
|
||||||
|
if len(videos) != num_video_tokens:
|
||||||
|
raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
@ -420,8 +677,12 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
PLUGINS = {
|
PLUGINS = {
|
||||||
"base": BasePlugin,
|
"base": BasePlugin,
|
||||||
"llava": LlavaPlugin,
|
"llava": LlavaPlugin,
|
||||||
|
"llava_next": LlavaNextPlugin,
|
||||||
|
"llava_next_video": LlavaNextVideoPlugin,
|
||||||
"paligemma": PaliGemmaPlugin,
|
"paligemma": PaliGemmaPlugin,
|
||||||
|
"pixtral": PixtralPlugin,
|
||||||
"qwen2_vl": Qwen2vlPlugin,
|
"qwen2_vl": Qwen2vlPlugin,
|
||||||
|
"video_llava": VideoLlavaPlugin,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -432,6 +693,6 @@ def get_mm_plugin(
|
|||||||
) -> "BasePlugin":
|
) -> "BasePlugin":
|
||||||
plugin_class = PLUGINS.get(name, None)
|
plugin_class = PLUGINS.get(name, None)
|
||||||
if plugin_class is None:
|
if plugin_class is None:
|
||||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
||||||
|
|
||||||
return plugin_class(image_token, video_token)
|
return plugin_class(image_token, video_token)
|
||||||
|
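The plugin registry above now exposes the new `llava_next`, `llava_next_video`, `pixtral`, and `video_llava` plugins next to the existing ones. A minimal lookup sketch; the token string is a placeholder, and in practice the chat template definitions register the model-specific image/video tokens:

```python
from llamafactory.data.mm_plugin import get_mm_plugin

# Placeholder token for illustration only.
plugin = get_mm_plugin(name="pixtral", image_token="<image>")
print(type(plugin).__name__)  # PixtralPlugin

# Unregistered names raise ValueError("Multimodal plugin `...` not found.")
```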
@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
|
|||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from ..extras.constants import DATA_CONFIG
|
from ..extras.constants import DATA_CONFIG
|
||||||
from ..extras.misc import use_modelscope
|
from ..extras.misc import use_modelscope, use_openmind
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -30,7 +30,7 @@ class DatasetAttr:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# basic configs
|
# basic configs
|
||||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
|
||||||
dataset_name: str
|
dataset_name: str
|
||||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||||
ranking: bool = False
|
ranking: bool = False
|
||||||
@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
|
|||||||
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(config_path, "r") as f:
|
with open(config_path) as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if len(dataset_names) != 0:
|
if len(dataset_names) != 0:
|
||||||
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
|
||||||
|
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
|
|
||||||
dataset_list: List["DatasetAttr"] = []
|
dataset_list: List["DatasetAttr"] = []
|
||||||
for name in dataset_names:
|
for name in dataset_names:
|
||||||
if dataset_info is None: # dataset_dir is ONLINE
|
if dataset_info is None: # dataset_dir is ONLINE
|
||||||
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
if use_modelscope():
|
||||||
|
load_from = "ms_hub"
|
||||||
|
elif use_openmind():
|
||||||
|
load_from = "om_hub"
|
||||||
|
else:
|
||||||
|
load_from = "hf_hub"
|
||||||
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||||
dataset_list.append(dataset_attr)
|
dataset_list.append(dataset_attr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if name not in dataset_info:
|
if name not in dataset_info:
|
||||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
|
||||||
|
|
||||||
has_hf_url = "hf_hub_url" in dataset_info[name]
|
has_hf_url = "hf_hub_url" in dataset_info[name]
|
||||||
has_ms_url = "ms_hub_url" in dataset_info[name]
|
has_ms_url = "ms_hub_url" in dataset_info[name]
|
||||||
|
has_om_url = "om_hub_url" in dataset_info[name]
|
||||||
|
|
||||||
if has_hf_url or has_ms_url:
|
if has_hf_url or has_ms_url or has_om_url:
|
||||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
if has_ms_url and (use_modelscope() or not has_hf_url):
|
||||||
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||||
|
elif has_om_url and (use_openmind() or not has_hf_url):
|
||||||
|
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
|
||||||
else:
|
else:
|
||||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||||
elif "script_url" in dataset_info[name]:
|
elif "script_url" in dataset_info[name]:
|
||||||
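Editor's note: the dataset-list hunks above add openMind as a third hub and change the URL precedence. A minimal, self-contained sketch of the selection order they implement is below. The `use_modelscope`/`use_openmind` helpers stand in for the ones imported from `..extras.misc`; here they are assumed to read the `USE_MODELSCOPE_HUB`/`USE_OPENMIND_HUB` environment variables listed in `.env.local`, and the final `"script"`/`"file"` fallback is a simplification of the remaining branches.

```python
import os


def use_modelscope() -> bool:
    # Assumption: mirrors the extras.misc helper, driven by USE_MODELSCOPE_HUB.
    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ("1", "true")


def use_openmind() -> bool:
    # Assumption: mirrors the extras.misc helper, driven by USE_OPENMIND_HUB.
    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ("1", "true")


def pick_hub(info: dict) -> str:
    """Reproduce the precedence from the hunk above: a ModelScope or openMind URL
    wins only when the matching hub is enabled or no Hugging Face URL exists."""
    has_hf = "hf_hub_url" in info
    has_ms = "ms_hub_url" in info
    has_om = "om_hub_url" in info
    if has_ms and (use_modelscope() or not has_hf):
        return "ms_hub"
    if has_om and (use_openmind() or not has_hf):
        return "om_hub"
    if has_hf:
        return "hf_hub"
    return "script" if "script_url" in info else "file"


os.environ["USE_OPENMIND_HUB"] = "1"
print(pick_hub({"hf_hub_url": "org/data", "om_hub_url": "org/data"}))  # om_hub
```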
|
@ -15,8 +15,8 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from ...extras import logging
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.logging import get_logger
|
|
||||||
from .processor_utils import infer_seqlen
|
from .processor_utils import infer_seqlen
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
|||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _encode_feedback_example(
|
def _encode_feedback_example(
|
||||||
@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
|
|||||||
model_inputs = defaultdict(list)
|
model_inputs = defaultdict(list)
|
||||||
for i in range(len(examples["_prompt"])):
|
for i in range(len(examples["_prompt"])):
|
||||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
logger.warning_rank0(
|
||||||
|
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
||||||
@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
|
|||||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||||
if desirable_num == 0 or undesirable_num == 0:
|
if desirable_num == 0 or undesirable_num == 0:
|
||||||
logger.warning("Your dataset only has one preference type.")
|
logger.warning_rank0("Your dataset only has one preference type.")
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
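Editor's note: the processor hunks above switch from `logger.warning` to `logger.warning_rank0` and obtain the logger via the new `extras.logging` module, so that under `torchrun` only one process prints the "Dropped invalid example" and "one preference type" messages. The real helper is not shown in this diff; the sketch below is one plausible shape for such a rank-aware wrapper, assuming the rank is read from the `RANK` variable that torchrun exports.

```python
import logging
import os


class Rank0Logger(logging.LoggerAdapter):
    """Illustrative only: forward *_rank0 calls to the wrapped logger, but drop
    the message on non-zero ranks (RANK is set by torchrun per process)."""

    def _is_rank0(self) -> bool:
        return int(os.getenv("RANK", "0")) == 0

    def warning_rank0(self, msg: str, *args) -> None:
        if self._is_rank0():
            self.logger.warning(msg, *args)

    def info_rank0(self, msg: str, *args) -> None:
        if self._is_rank0():
            self.logger.info(msg, *args)


def get_logger(name: str) -> Rank0Logger:
    return Rank0Logger(logging.getLogger(name), {})


logger = get_logger(__name__)
logger.warning_rank0("Your dataset only has one preference type.")
```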
|
@ -15,8 +15,8 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from ...extras import logging
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.logging import get_logger
|
|
||||||
from .processor_utils import infer_seqlen
|
from .processor_utils import infer_seqlen
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
|||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _encode_pairwise_example(
|
def _encode_pairwise_example(
|
||||||
@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
|
|||||||
model_inputs = defaultdict(list)
|
model_inputs = defaultdict(list)
|
||||||
for i in range(len(examples["_prompt"])):
|
for i in range(len(examples["_prompt"])):
|
||||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
logger.warning_rank0(
|
||||||
|
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||||
@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
|
|||||||
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
||||||
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
||||||
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
||||||
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
|
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
|
||||||
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
||||||
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
||||||
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
||||||
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
|
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
|
||||||
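Editor's note: the printing hunk above decodes `valid_chosen_labels`/`valid_rejected_labels`, which are not built inside this hunk. Presumably they are the label ids with the `IGNORE_INDEX` padding stripped, since `-100` cannot be decoded. A small sketch of that idea:

```python
from typing import List

IGNORE_INDEX = -100  # value used in constants.py to mask prompt tokens out of the loss


def strip_ignored(label_ids: List[int]) -> List[int]:
    """Drop masked positions so the remaining ids can be passed to tokenizer.decode."""
    return [token_id for token_id in label_ids if token_id != IGNORE_INDEX]


# Toy label sequence: the prompt part is masked, only the response ids survive.
chosen_labels = [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 9906, 1917, 2]
print(strip_ignored(chosen_labels))  # [9906, 1917, 2]
```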
|
@ -15,8 +15,8 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from ...extras import logging
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.logging import get_logger
|
|
||||||
from .processor_utils import greedy_knapsack, infer_seqlen
|
from .processor_utils import greedy_knapsack, infer_seqlen
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
|||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _encode_supervised_example(
|
def _encode_supervised_example(
|
||||||
@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
|
|||||||
model_inputs = defaultdict(list)
|
model_inputs = defaultdict(list)
|
||||||
for i in range(len(examples["_prompt"])):
|
for i in range(len(examples["_prompt"])):
|
||||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
logger.warning_rank0(
|
||||||
|
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
input_ids, labels = _encode_supervised_example(
|
input_ids, labels = _encode_supervised_example(
|
||||||
@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
|
|||||||
length2indexes = defaultdict(list)
|
length2indexes = defaultdict(list)
|
||||||
for i in range(len(examples["_prompt"])):
|
for i in range(len(examples["_prompt"])):
|
||||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
logger.warning_rank0(
|
||||||
|
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
input_ids, labels = _encode_supervised_example(
|
input_ids, labels = _encode_supervised_example(
|
||||||
@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
|
|||||||
)
|
)
|
||||||
length = len(input_ids)
|
length = len(input_ids)
|
||||||
if length > data_args.cutoff_len:
|
if length > data_args.cutoff_len:
|
||||||
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
|
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
|
||||||
else:
|
else:
|
||||||
lengths.append(length)
|
lengths.append(length)
|
||||||
length2indexes[length].append(valid_num)
|
length2indexes[length].append(valid_num)
|
||||||
@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
|
|||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
print("label_ids:\n{}".format(example["labels"]))
|
print("label_ids:\n{}".format(example["labels"]))
|
||||||
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
|
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
|
||||||
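Editor's note: in the packed supervised path above, examples longer than `cutoff_len` are dropped with a rank-0 warning and the rest are grouped by length (`length2indexes`) before `greedy_knapsack` bins them into sequences of at most `cutoff_len` tokens. The helper itself is not part of this hunk; the sketch below is only an illustration of length-based packing under that assumption, not the repository's `greedy_knapsack` implementation.

```python
from typing import List


def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    """Illustrative packing: sort lengths descending and drop each one into the
    first bin with enough room, opening a new bin otherwise."""
    bins: List[List[int]] = []
    remaining: List[int] = []
    for length in sorted(lengths, reverse=True):
        for i, free in enumerate(remaining):
            if length <= free:
                bins[i].append(length)
                remaining[i] -= length
                break
        else:
            bins.append([length])
            remaining.append(capacity - length)
    return bins


print(greedy_knapsack_sketch([900, 700, 500, 300, 100], capacity=1024))
# [[900, 100], [700, 300], [500]] -- each bin stays within cutoff_len
```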
|
@ -15,7 +15,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from ...extras.logging import get_logger
|
from ...extras import logging
|
||||||
from ..data_utils import Role
|
from ..data_utils import Role
|
||||||
from .processor_utils import infer_seqlen
|
from .processor_utils import infer_seqlen
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
|||||||
from ..template import Template
|
from ..template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _encode_unsupervised_example(
|
def _encode_unsupervised_example(
|
||||||
@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
|
|||||||
model_inputs = defaultdict(list)
|
model_inputs = defaultdict(list)
|
||||||
for i in range(len(examples["_prompt"])):
|
for i in range(len(examples["_prompt"])):
|
||||||
if len(examples["_prompt"][i]) % 2 != 1:
|
if len(examples["_prompt"][i]) % 2 != 1:
|
||||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
logger.warning_rank0(
|
||||||
|
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
input_ids, labels = _encode_unsupervised_example(
|
input_ids, labels = _encode_unsupervised_example(
|
||||||
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
from .data_utils import Role
|
from .data_utils import Role
|
||||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||||
from .mm_plugin import get_mm_plugin
|
from .mm_plugin import get_mm_plugin
|
||||||
@ -32,7 +32,7 @@ if TYPE_CHECKING:
|
|||||||
from .mm_plugin import BasePlugin
|
from .mm_plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -49,6 +49,7 @@ class Template:
|
|||||||
stop_words: List[str]
|
stop_words: List[str]
|
||||||
efficient_eos: bool
|
efficient_eos: bool
|
||||||
replace_eos: bool
|
replace_eos: bool
|
||||||
|
replace_jinja_template: bool
|
||||||
mm_plugin: "BasePlugin"
|
mm_plugin: "BasePlugin"
|
||||||
|
|
||||||
def encode_oneturn(
|
def encode_oneturn(
|
||||||
@ -146,7 +147,7 @@ class Template:
|
|||||||
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
||||||
token_ids += [tokenizer.eos_token_id]
|
token_ids += [tokenizer.eos_token_id]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
|
||||||
|
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
@ -214,6 +215,7 @@ def _register_template(
|
|||||||
stop_words: Sequence[str] = [],
|
stop_words: Sequence[str] = [],
|
||||||
efficient_eos: bool = False,
|
efficient_eos: bool = False,
|
||||||
replace_eos: bool = False,
|
replace_eos: bool = False,
|
||||||
|
replace_jinja_template: bool = True,
|
||||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""
|
r"""
|
||||||
@ -263,6 +265,7 @@ def _register_template(
|
|||||||
stop_words=stop_words,
|
stop_words=stop_words,
|
||||||
efficient_eos=efficient_eos,
|
efficient_eos=efficient_eos,
|
||||||
replace_eos=replace_eos,
|
replace_eos=replace_eos,
|
||||||
|
replace_jinja_template=replace_jinja_template,
|
||||||
mm_plugin=mm_plugin,
|
mm_plugin=mm_plugin,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -272,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
|
|||||||
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
||||||
|
|
||||||
if is_added:
|
if is_added:
|
||||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
|
||||||
else:
|
else:
|
||||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
|
||||||
|
|
||||||
if num_added_tokens > 0:
|
if num_added_tokens > 0:
|
||||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||||
|
|
||||||
|
|
||||||
def _jinja_escape(content: str) -> str:
|
def _jinja_escape(content: str) -> str:
|
||||||
@ -353,24 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
|||||||
r"""
|
r"""
|
||||||
Gets chat template and fixes the tokenizer.
|
Gets chat template and fixes the tokenizer.
|
||||||
"""
|
"""
|
||||||
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
|
|
||||||
require_version(
|
|
||||||
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
|
||||||
)
|
|
||||||
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
|
|
||||||
|
|
||||||
if data_args.template is None:
|
if data_args.template is None:
|
||||||
template = TEMPLATES["empty"] # placeholder
|
template = TEMPLATES["empty"] # placeholder
|
||||||
else:
|
else:
|
||||||
template = TEMPLATES.get(data_args.template, None)
|
template = TEMPLATES.get(data_args.template, None)
|
||||||
if template is None:
|
if template is None:
|
||||||
raise ValueError("Template {} does not exist.".format(data_args.template))
|
raise ValueError(f"Template {data_args.template} does not exist.")
|
||||||
|
|
||||||
|
if template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||||
|
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
|
||||||
|
|
||||||
if data_args.train_on_prompt and template.efficient_eos:
|
if data_args.train_on_prompt and template.efficient_eos:
|
||||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||||
|
|
||||||
if data_args.tool_format is not None:
|
if data_args.tool_format is not None:
|
||||||
logger.info("Using tool format: {}.".format(data_args.tool_format))
|
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
|
||||||
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
||||||
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
|
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
|
||||||
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
||||||
@ -388,20 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
|||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
|
||||||
|
|
||||||
if stop_words:
|
if stop_words:
|
||||||
num_added_tokens = tokenizer.add_special_tokens(
|
num_added_tokens = tokenizer.add_special_tokens(
|
||||||
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
||||||
)
|
)
|
||||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
|
||||||
if num_added_tokens > 0:
|
if num_added_tokens > 0:
|
||||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||||
|
|
||||||
|
if tokenizer.chat_template is None or template.replace_jinja_template:
|
||||||
try:
|
try:
|
||||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
||||||
except ValueError:
|
except ValueError as e:
|
||||||
logger.info("Cannot add this chat template to tokenizer.")
|
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
||||||
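Editor's note: two behavioural changes in `get_template_and_fix_tokenizer` are worth calling out. The hard-coded template list for the transformers version check is replaced by a check on the template's `mm_plugin` class, and the generated jinja chat template is only written to the tokenizer when the tokenizer has none or the template opts in via the new `replace_jinja_template` flag. A minimal sketch of the second condition, with a stand-in for the real `Template` class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TemplateStub:
    # Stand-in for the real Template; only the fields this sketch needs.
    replace_jinja_template: bool
    jinja: str = "{{ messages }}"


def maybe_set_chat_template(current: Optional[str], template: TemplateStub) -> Optional[str]:
    """Mirror the new condition: the built-in jinja template overwrites the
    tokenizer's chat_template only if none exists or the template opts in."""
    if current is None or template.replace_jinja_template:
        return template.jinja
    return current


# Templates such as llama3, gemma and qwen now register replace_jinja_template=False,
# so a chat_template shipped with the tokenizer is kept as-is.
print(maybe_set_chat_template("<shipped template>", TemplateStub(replace_jinja_template=False)))
print(maybe_set_chat_template(None, TemplateStub(replace_jinja_template=False)))
```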
@ -640,6 +641,14 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="exaone",
|
||||||
|
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
|
||||||
|
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
|
||||||
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
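Editor's note: the new `exaone` template wraps turns with `[|system|]`, `[|user|]`, `[|assistant|]` and `[|endofturn|]` markers. Assuming the `StringFormatter` simply substitutes `{{content}}`, a one-turn prompt would render roughly as below; this is plain string building for illustration, while the real encoding is done token-wise by the template machinery, and the system text is just a placeholder.

```python
def render_exaone(system: str, user: str) -> str:
    # Illustrative rendering of the slots registered above.
    prompt = ""
    if system:
        prompt += f"[|system|]{system}[|endofturn|]\n"
    prompt += f"[|user|]{user}\n[|assistant|]"
    return prompt


print(render_exaone("You are a helpful assistant.", "Hello!"))
```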
_register_template(
|
_register_template(
|
||||||
name="falcon",
|
name="falcon",
|
||||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||||
@ -664,6 +673,7 @@ _register_template(
|
|||||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
efficient_eos=True,
|
efficient_eos=True,
|
||||||
|
replace_jinja_template=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -681,6 +691,14 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="index",
|
||||||
|
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
|
||||||
|
format_system=StringFormatter(slots=["<unk>{{content}}"]),
|
||||||
|
efficient_eos=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="intern",
|
name="intern",
|
||||||
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
||||||
@ -740,6 +758,7 @@ _register_template(
|
|||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|eot_id|>"],
|
stop_words=["<|eot_id|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
|
replace_jinja_template=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -754,6 +773,107 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next",
|
||||||
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
|
default_system=(
|
||||||
|
"A chat between a curious user and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||||
|
),
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next_llama3",
|
||||||
|
format_user=StringFormatter(
|
||||||
|
slots=[
|
||||||
|
(
|
||||||
|
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=[
|
||||||
|
(
|
||||||
|
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
stop_words=["<|eot_id|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
replace_jinja_template=False,
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next_mistral",
|
||||||
|
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next_qwen",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
|
default_system="You are a helpful assistant.",
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
replace_jinja_template=False,
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next_yi",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next_video",
|
||||||
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
|
default_system=(
|
||||||
|
"A chat between a curious user and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||||
|
),
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next_video_mistral",
|
||||||
|
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="llava_next_video_yi",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
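Editor's note: several near-identical LLaVA-NeXT templates are registered above because each checkpoint reuses the chat format of its language backbone (Vicuna, Mistral, Llama 3, Yi, Qwen) while sharing the `llava_next`/`llava_next_video` multimodal plugins. The mapping sketched below is taken from the `register_model_group` hunks further down in constants.py and only illustrates which template goes with which checkpoint family.

```python
# Sketch: which llava_next template pairs with which backbone, per the
# register_model_group calls added later in constants.py.
LLAVA_NEXT_TEMPLATES = {
    "llava-v1.6-vicuna-7b-hf": "llava_next",
    "llava-v1.6-vicuna-13b-hf": "llava_next",
    "llava-v1.6-mistral-7b-hf": "llava_next_mistral",
    "llama3-llava-next-8b-hf": "llava_next_llama3",
    "llava-v1.6-34b-hf": "llava_next_yi",
    "llava-next-72b-hf": "llava_next_qwen",
    "llava-next-110b-hf": "llava_next_qwen",
    "LLaVA-NeXT-Video-7B-hf": "llava_next_video",
    "LLaVA-NeXT-Video-7B-32K-hf": "llava_next_video_mistral",
    "LLaVA-NeXT-Video-34B-hf": "llava_next_video_yi",
}

print(LLAVA_NEXT_TEMPLATES["llava-v1.6-mistral-7b-hf"])  # llava_next_mistral
```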
_register_template(
|
_register_template(
|
||||||
name="mistral",
|
name="mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||||
@ -831,6 +951,14 @@ _register_template(
|
|||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="pixtral",
|
||||||
|
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="qwen",
|
name="qwen",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
@ -840,6 +968,7 @@ _register_template(
|
|||||||
default_system="You are a helpful assistant.",
|
default_system="You are a helpful assistant.",
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
|
replace_jinja_template=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -852,6 +981,7 @@ _register_template(
|
|||||||
default_system="You are a helpful assistant.",
|
default_system="You are a helpful assistant.",
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
|
replace_jinja_template=False,
|
||||||
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -907,6 +1037,17 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="video_llava",
|
||||||
|
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||||
|
default_system=(
|
||||||
|
"A chat between a curious user and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||||
|
),
|
||||||
|
mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="xuanyuan",
|
name="xuanyuan",
|
||||||
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
||||||
|
@ -177,6 +177,6 @@ TOOLS = {
|
|||||||
def get_tool_utils(name: str) -> "ToolUtils":
|
def get_tool_utils(name: str) -> "ToolUtils":
|
||||||
tool_utils = TOOLS.get(name, None)
|
tool_utils = TOOLS.get(name, None)
|
||||||
if tool_utils is None:
|
if tool_utils is None:
|
||||||
raise ValueError("Tool utils `{}` not found.".format(name))
|
raise ValueError(f"Tool utils `{name}` not found.")
|
||||||
|
|
||||||
return tool_utils
|
return tool_utils
|
||||||
|
@ -87,7 +87,7 @@ class Evaluator:
|
|||||||
token=self.model_args.hf_hub_token,
|
token=self.model_args.hf_hub_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(mapping, "r", encoding="utf-8") as f:
|
with open(mapping, encoding="utf-8") as f:
|
||||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||||
|
|
||||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||||
@ -139,7 +139,7 @@ class Evaluator:
|
|||||||
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
||||||
score_info = "\n".join(
|
score_info = "\n".join(
|
||||||
[
|
[
|
||||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
|
||||||
for category_name, category_correct in category_corrects.items()
|
for category_name, category_correct in category_corrects.items()
|
||||||
if len(category_correct)
|
if len(category_correct)
|
||||||
]
|
]
|
||||||
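Editor's note: the evaluator's score report switches from `str.format` to an f-string with the same alignment and precision specs, so the change is purely cosmetic. The two forms are equivalent, as the small check below shows (category data is made up for the example).

```python
import numpy as np

category_correct = np.array([True, True, False, True])
category_name = "STEM"

old_style = "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
new_style = f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
assert old_style == new_style
print(new_style)  # name right-aligned in a 15-char field, score as 75.00
```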
|
@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
|
|||||||
|
|
||||||
def get_eval_template(name: str) -> "EvalTemplate":
|
def get_eval_template(name: str) -> "EvalTemplate":
|
||||||
eval_template = eval_templates.get(name, None)
|
eval_template = eval_templates.get(name, None)
|
||||||
assert eval_template is not None, "Template {} does not exist.".format(name)
|
assert eval_template is not None, f"Template {name} does not exist."
|
||||||
return eval_template
|
return eval_template
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
@ -47,7 +48,7 @@ FILEEXT2TYPE = {
|
|||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
IMAGE_PLACEHOLDER = "<image>"
|
IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
|
||||||
|
|
||||||
LAYERNORM_NAMES = {"norm", "ln"}
|
LAYERNORM_NAMES = {"norm", "ln"}
|
||||||
|
|
||||||
@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
|
|||||||
|
|
||||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||||
|
|
||||||
VIDEO_PLACEHOLDER = "<video>"
|
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
|
||||||
|
|
||||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||||
|
|
||||||
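Editor's note: `IMAGE_PLACEHOLDER` and `VIDEO_PLACEHOLDER` become overridable via environment variables, with the previous literals kept as defaults. Because `os.environ.get` runs at import time, the variable must be set before `extras.constants` is imported; a minimal illustration of the pattern:

```python
import os

# Set before the constants module is imported, otherwise the default is baked in.
os.environ["IMAGE_PLACEHOLDER"] = "<img>"

IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

print(IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER)  # <img> <video>
```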
@ -107,6 +108,7 @@ VISION_MODELS = set()
|
|||||||
class DownloadSource(str, Enum):
|
class DownloadSource(str, Enum):
|
||||||
DEFAULT = "hf"
|
DEFAULT = "hf"
|
||||||
MODELSCOPE = "ms"
|
MODELSCOPE = "ms"
|
||||||
|
OPENMIND = "om"
|
||||||
|
|
||||||
|
|
||||||
def register_model_group(
|
def register_model_group(
|
||||||
@ -114,17 +116,12 @@ def register_model_group(
|
|||||||
template: Optional[str] = None,
|
template: Optional[str] = None,
|
||||||
vision: bool = False,
|
vision: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
prefix = None
|
|
||||||
for name, path in models.items():
|
for name, path in models.items():
|
||||||
if prefix is None:
|
|
||||||
prefix = name.split("-")[0]
|
|
||||||
else:
|
|
||||||
assert prefix == name.split("-")[0], "prefix should be identical."
|
|
||||||
SUPPORTED_MODELS[name] = path
|
SUPPORTED_MODELS[name] = path
|
||||||
if template is not None:
|
if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
|
||||||
DEFAULT_TEMPLATE[prefix] = template
|
DEFAULT_TEMPLATE[name] = template
|
||||||
if vision:
|
if vision:
|
||||||
VISION_MODELS.add(prefix)
|
VISION_MODELS.add(name)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
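Editor's note: `register_model_group` no longer derives a shared prefix from the first entry. The default template is now keyed by the full model name, attached only to names carrying a `-Chat` or `-Instruct` suffix, and vision models are tracked by full name as well. A self-contained sketch of the new bookkeeping, with a plain string standing in for the `DownloadSource` keys:

```python
from collections import OrderedDict
from typing import Dict, Optional

SUPPORTED_MODELS = OrderedDict()
DEFAULT_TEMPLATE: Dict[str, str] = {}
VISION_MODELS = set()


def register_model_group(models: Dict[str, Dict], template: Optional[str] = None, vision: bool = False) -> None:
    # Mirrors the hunk above: no prefix bookkeeping, keys are the full model names.
    for name, path in models.items():
        SUPPORTED_MODELS[name] = path
        if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
            DEFAULT_TEMPLATE[name] = template
        if vision:
            VISION_MODELS.add(name)


register_model_group(
    models={
        "Llama-3-8B": {"hf": "meta-llama/Meta-Llama-3-8B"},
        "Llama-3-8B-Instruct": {"hf": "meta-llama/Meta-Llama-3-8B-Instruct"},
    },
    template="llama3",
)
print(DEFAULT_TEMPLATE)  # only the -Instruct entry maps to "llama3"
```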
@ -168,14 +165,17 @@ register_model_group(
|
|||||||
"Baichuan2-13B-Base": {
|
"Baichuan2-13B-Base": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
||||||
|
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
|
||||||
},
|
},
|
||||||
"Baichuan2-7B-Chat": {
|
"Baichuan2-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
||||||
|
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
|
||||||
},
|
},
|
||||||
"Baichuan2-13B-Chat": {
|
"Baichuan2-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
||||||
|
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
template="baichuan2",
|
template="baichuan2",
|
||||||
@ -274,27 +274,27 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChineseLLaMA2-1.3B": {
|
"Chinese-Llama-2-1.3B": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-7B": {
|
"Chinese-Llama-2-7B": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-13B": {
|
"Chinese-Llama-2-13B": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-1.3B-Chat": {
|
"Chinese-Alpaca-2-1.3B-Chat": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-7B-Chat": {
|
"Chinese-Alpaca-2-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
|
||||||
},
|
},
|
||||||
"ChineseLLaMA2-13B-Chat": {
|
"Chinese-Alpaca-2-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
|
||||||
},
|
},
|
||||||
@ -450,25 +450,25 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"DeepSeekCoder-6.7B-Base": {
|
"DeepSeek-Coder-6.7B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-7B-Base": {
|
"DeepSeek-Coder-7B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-33B-Base": {
|
"DeepSeek-Coder-33B-Base": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-6.7B-Instruct": {
|
"DeepSeek-Coder-6.7B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-7B-Instruct": {
|
"DeepSeek-Coder-7B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
|
||||||
},
|
},
|
||||||
"DeepSeekCoder-33B-Instruct": {
|
"DeepSeek-Coder-33B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||||
},
|
},
|
||||||
@ -477,6 +477,16 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"EXAONE-3.0-7.8B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="exaone",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Falcon-7B": {
|
"Falcon-7B": {
|
||||||
@ -550,10 +560,12 @@ register_model_group(
|
|||||||
"Gemma-2-2B-Instruct": {
|
"Gemma-2-2B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "google/gemma-2-2b-it",
|
DownloadSource.DEFAULT: "google/gemma-2-2b-it",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
|
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
|
||||||
},
|
},
|
||||||
"Gemma-2-9B-Instruct": {
|
"Gemma-2-9B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
|
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
|
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
|
||||||
},
|
},
|
||||||
"Gemma-2-27B-Instruct": {
|
"Gemma-2-27B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
|
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
|
||||||
@ -573,6 +585,7 @@ register_model_group(
|
|||||||
"GLM-4-9B-Chat": {
|
"GLM-4-9B-Chat": {
|
||||||
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
|
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
|
||||||
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
|
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
|
||||||
},
|
},
|
||||||
"GLM-4-9B-1M-Chat": {
|
"GLM-4-9B-1M-Chat": {
|
||||||
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
|
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
|
||||||
@ -583,6 +596,33 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Index-1.9B-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
|
||||||
|
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
|
||||||
|
},
|
||||||
|
"Index-1.9B-Character-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
|
||||||
|
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
|
||||||
|
},
|
||||||
|
"Index-1.9B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
|
||||||
|
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
|
||||||
|
},
|
||||||
|
"Index-1.9B-Base-Pure": {
|
||||||
|
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
|
||||||
|
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
|
||||||
|
},
|
||||||
|
"Index-1.9B-Chat-32K": {
|
||||||
|
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
|
||||||
|
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="index",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"InternLM-7B": {
|
"InternLM-7B": {
|
||||||
@ -624,16 +664,10 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
||||||
},
|
},
|
||||||
},
|
|
||||||
template="intern2",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
|
||||||
models={
|
|
||||||
"InternLM2.5-1.8B": {
|
"InternLM2.5-1.8B": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
|
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
|
||||||
|
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
|
||||||
},
|
},
|
||||||
"InternLM2.5-7B": {
|
"InternLM2.5-7B": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
|
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
|
||||||
@ -642,22 +676,27 @@ register_model_group(
|
|||||||
"InternLM2.5-20B": {
|
"InternLM2.5-20B": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
|
DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
|
||||||
|
DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
|
||||||
},
|
},
|
||||||
"InternLM2.5-1.8B-Chat": {
|
"InternLM2.5-1.8B-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
|
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
|
||||||
|
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
|
||||||
},
|
},
|
||||||
"InternLM2.5-7B-Chat": {
|
"InternLM2.5-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
|
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
|
||||||
|
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
|
||||||
},
|
},
|
||||||
"InternLM2.5-7B-1M-Chat": {
|
"InternLM2.5-7B-1M-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
|
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
|
||||||
|
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
|
||||||
},
|
},
|
||||||
"InternLM2.5-20B-Chat": {
|
"InternLM2.5-20B-Chat": {
|
||||||
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
|
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
|
||||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
|
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
|
||||||
|
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
template="intern2",
|
template="intern2",
|
||||||
@ -686,19 +725,19 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA-7B": {
|
"Llama-7B": {
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
|
||||||
},
|
},
|
||||||
"LLaMA-13B": {
|
"Llama-13B": {
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
||||||
},
|
},
|
||||||
"LLaMA-30B": {
|
"Llama-30B": {
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
|
||||||
},
|
},
|
||||||
"LLaMA-65B": {
|
"Llama-65B": {
|
||||||
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
||||||
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
|
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
|
||||||
},
|
},
|
||||||
@ -708,27 +747,27 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA2-7B": {
|
"Llama-2-7B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-13B": {
|
"Llama-2-13B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-70B": {
|
"Llama-2-70B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-7B-Chat": {
|
"Llama-2-7B-Chat": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-13B-Chat": {
|
"Llama-2-13B-Chat": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
|
||||||
},
|
},
|
||||||
"LLaMA2-70B-Chat": {
|
"Llama-2-70B-Chat": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
||||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
|
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
|
||||||
},
|
},
|
||||||
@ -739,60 +778,78 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA3-8B": {
|
"Llama-3-8B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
|
||||||
},
|
},
|
||||||
"LLaMA3-70B": {
|
"Llama-3-70B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
|
||||||
},
|
},
|
||||||
"LLaMA3-8B-Instruct": {
|
"Llama-3-8B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
|
||||||
},
|
},
|
||||||
"LLaMA3-70B-Instruct": {
|
"Llama-3-70B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
|
||||||
},
|
},
|
||||||
"LLaMA3-8B-Chinese-Chat": {
|
"Llama-3-8B-Chinese-Chat": {
|
||||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
|
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
|
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
|
||||||
|
DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
|
||||||
},
|
},
|
||||||
"LLaMA3-70B-Chinese-Chat": {
|
"Llama-3-70B-Chinese-Chat": {
|
||||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
|
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
|
||||||
},
|
},
|
||||||
},
|
"Llama-3.1-8B": {
|
||||||
template="llama3",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
|
||||||
models={
|
|
||||||
"LLaMA3.1-8B": {
|
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
|
||||||
},
|
},
|
||||||
"LLaMA3.1-70B": {
|
"Llama-3.1-70B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
|
||||||
},
|
},
|
||||||
"LLaMA3.1-405B": {
|
"Llama-3.1-405B": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
|
||||||
},
|
},
|
||||||
"LLaMA3.1-8B-Instruct": {
|
"Llama-3.1-8B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
|
||||||
},
|
},
|
||||||
"LLaMA3.1-70B-Instruct": {
|
"Llama-3.1-70B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
|
||||||
},
|
},
|
||||||
"LLaMA3.1-405B-Instruct": {
|
"Llama-3.1-405B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
|
||||||
},
|
},
|
||||||
|
"Llama-3.1-8B-Chinese-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-8B-Chinese-Chat",
|
||||||
|
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-8B-Chinese-Chat",
|
||||||
|
},
|
+        "Llama-3.1-70B-Chinese-Chat": {
+            DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-70B-Chinese-Chat",
+            DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-70B-Chinese-Chat",
+        },
+        "Llama-3.2-1B": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B",
+        },
+        "Llama-3.2-3B": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B",
+        },
+        "Llama-3.2-1B-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B-Instruct",
+        },
+        "Llama-3.2-3B-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
+        },
     },
     template="llama3",
 )
@@ -800,11 +857,13 @@ register_model_group
 
 register_model_group(
     models={
-        "LLaVA1.5-7B-Chat": {
+        "LLaVA-1.5-7B-Chat": {
             DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
+            DownloadSource.MODELSCOPE: "swift/llava-1.5-7b-hf",
         },
-        "LLaVA1.5-13B-Chat": {
+        "LLaVA-1.5-13B-Chat": {
             DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
+            DownloadSource.MODELSCOPE: "swift/llava-1.5-13b-hf",
         },
     },
     template="llava",
@@ -812,6 +871,117 @@ register_model_group
 )
 
 
+register_model_group(
+    models={
+        "LLaVA-NeXT-7B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-7b-hf",
+            DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-7b-hf",
+        },
+        "LLaVA-NeXT-13B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-13b-hf",
+            DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-13b-hf",
+        },
+    },
+    template="llava_next",
+    vision=True,
+)
+
+
+register_model_group(
+    models={
+        "LLaVA-NeXT-Mistral-7B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-v1.6-mistral-7b-hf",
+            DownloadSource.MODELSCOPE: "swift/llava-v1.6-mistral-7b-hf",
+        },
+    },
+    template="llava_next_mistral",
+    vision=True,
+)
+
+
+register_model_group(
+    models={
+        "LLaVA-NeXT-Llama3-8B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llama3-llava-next-8b-hf",
+            DownloadSource.MODELSCOPE: "swift/llama3-llava-next-8b-hf",
+        },
+    },
+    template="llava_next_llama3",
+    vision=True,
+)
+
+
+register_model_group(
+    models={
+        "LLaVA-NeXT-34B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-v1.6-34b-hf",
+            DownloadSource.MODELSCOPE: "LLM-Research/llava-v1.6-34b-hf",
+        },
+    },
+    template="llava_next_yi",
+    vision=True,
+)
+
+
+register_model_group(
+    models={
+        "LLaVA-NeXT-72B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-next-72b-hf",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-72b-hf",
+        },
+        "LLaVA-NeXT-110B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-next-110b-hf",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-110b-hf",
+        },
+    },
+    template="llava_next_qwen",
+    vision=True,
+)
+
+
+register_model_group(
+    models={
+        "LLaVA-NeXT-Video-7B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-hf",
+            DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-hf",
+        },
+        "LLaVA-NeXT-Video-7B-DPO-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-DPO-hf",
+            DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-DPO-hf",
+        },
+    },
+    template="llava_next_video",
+    vision=True,
+)
+
+
+register_model_group(
+    models={
+        "LLaVA-NeXT-Video-7B-32k-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-32K-hf",
+            DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-32K-hf",
+        },
+    },
+    template="llava_next_video_mistral",
+    vision=True,
+)
+
+
+register_model_group(
+    models={
+        "LLaVA-NeXT-Video-34B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-hf",
+            DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-34B-hf",
+        },
+        "LLaVA-NeXT-Video-34B-DPO-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-DPO-hf",
+        },
+    },
+    template="llava_next_video_yi",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
@@ -832,6 +1002,7 @@ register_model_group
         "MiniCPM3-4B-Chat": {
             DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
+            DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
         },
     },
     template="cpm3",
@@ -1005,27 +1176,27 @@ register_model_group
 
 register_model_group(
     models={
-        "Phi3-4B-4k-Instruct": {
+        "Phi-3-4B-4k-Instruct": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
         },
-        "Phi3-4B-128k-Instruct": {
+        "Phi-3-4B-128k-Instruct": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
         },
-        "Phi3-7B-8k-Instruct": {
+        "Phi-3-7B-8k-Instruct": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
         },
-        "Phi3-7B-128k-Instruct": {
+        "Phi-3-7B-128k-Instruct": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
         },
-        "Phi3-14B-8k-Instruct": {
+        "Phi-3-14B-8k-Instruct": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
         },
-        "Phi3-14B-128k-Instruct": {
+        "Phi-3-14B-128k-Instruct": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
         },
@@ -1034,6 +1205,18 @@ register_model_group
 )
 
 
+register_model_group(
+    models={
+        "Pixtral-12B-Chat": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="pixtral",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "Qwen-1.8B": {
@@ -1068,35 +1251,35 @@ register_model_group
             DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
         },
-        "Qwen-1.8B-int8-Chat": {
+        "Qwen-1.8B-Chat-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
         },
-        "Qwen-1.8B-int4-Chat": {
+        "Qwen-1.8B-Chat-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
         },
-        "Qwen-7B-int8-Chat": {
+        "Qwen-7B-Chat-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
         },
-        "Qwen-7B-int4-Chat": {
+        "Qwen-7B-Chat-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
         },
-        "Qwen-14B-int8-Chat": {
+        "Qwen-14B-Chat-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
         },
-        "Qwen-14B-int4-Chat": {
+        "Qwen-14B-Chat-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
        },
-        "Qwen-72B-int8-Chat": {
+        "Qwen-72B-Chat-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
         },
-        "Qwen-72B-int4-Chat": {
+        "Qwen-72B-Chat-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
         },
@@ -1179,75 +1362,75 @@ register_model_group
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
         },
-        "Qwen1.5-0.5B-int8-Chat": {
+        "Qwen1.5-0.5B-Chat-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
         },
-        "Qwen1.5-0.5B-int4-Chat": {
+        "Qwen1.5-0.5B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
         },
-        "Qwen1.5-1.8B-int8-Chat": {
+        "Qwen1.5-1.8B-Chat-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
         },
-        "Qwen1.5-1.8B-int4-Chat": {
+        "Qwen1.5-1.8B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
         },
-        "Qwen1.5-4B-int8-Chat": {
+        "Qwen1.5-4B-Chat-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
         },
-        "Qwen1.5-4B-int4-Chat": {
+        "Qwen1.5-4B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
         },
-        "Qwen1.5-7B-int8-Chat": {
+        "Qwen1.5-7B-Chat-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
         },
-        "Qwen1.5-7B-int4-Chat": {
+        "Qwen1.5-7B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
         },
-        "Qwen1.5-14B-int8-Chat": {
+        "Qwen1.5-14B-Chat-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
         },
-        "Qwen1.5-14B-int4-Chat": {
+        "Qwen1.5-14B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
         },
-        "Qwen1.5-32B-int4-Chat": {
+        "Qwen1.5-32B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
         },
-        "Qwen1.5-72B-int8-Chat": {
+        "Qwen1.5-72B-Chat-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
         },
-        "Qwen1.5-72B-int4-Chat": {
+        "Qwen1.5-72B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
         },
-        "Qwen1.5-110B-int4-Chat": {
+        "Qwen1.5-110B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
         },
-        "Qwen1.5-MoE-A2.7B-int4-Chat": {
+        "Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
         },
-        "Qwen1.5-Code-7B": {
+        "CodeQwen1.5-7B": {
             DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
             DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
         },
-        "Qwen1.5-Code-7B-Chat": {
+        "CodeQwen1.5-7B-Chat": {
             DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
             DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
         },
-        "Qwen1.5-Code-7B-int4-Chat": {
+        "CodeQwen1.5-7B-Chat-AWQ": {
             DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
             DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
         },
@@ -1281,14 +1464,17 @@ register_model_group
         "Qwen2-0.5B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
         },
         "Qwen2-1.5B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
         },
         "Qwen2-7B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
         },
         "Qwen2-72B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
@@ -1568,51 +1754,53 @@ register_model_group
 
 register_model_group(
     models={
-        "Qwen2VL-2B-Instruct": {
+        "Qwen2-VL-2B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
         },
-        "Qwen2VL-7B-Instruct": {
+        "Qwen2-VL-7B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
         },
-        "Qwen2VL-72B-Instruct": {
+        "Qwen2-VL-72B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
         },
-        "Qwen2VL-2B-Instruct-GPTQ-Int8": {
+        "Qwen2-VL-2B-Instruct-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
         },
-        "Qwen2VL-2B-Instruct-GPTQ-Int4": {
+        "Qwen2-VL-2B-Instruct-GPTQ-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
         },
-        "Qwen2VL-2B-Instruct-AWQ": {
+        "Qwen2-VL-2B-Instruct-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
         },
-        "Qwen2VL-7B-Instruct-GPTQ-Int8": {
+        "Qwen2-VL-7B-Instruct-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
         },
-        "Qwen2VL-7B-Instruct-GPTQ-Int4": {
+        "Qwen2-VL-7B-Instruct-GPTQ-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
         },
-        "Qwen2VL-7B-Instruct-AWQ": {
+        "Qwen2-VL-7B-Instruct-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
         },
-        "Qwen2VL-72B-Instruct-GPTQ-Int8": {
+        "Qwen2-VL-72B-Instruct-GPTQ-Int8": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
         },
-        "Qwen2VL-72B-Instruct-GPTQ-Int4": {
+        "Qwen2-VL-72B-Instruct-GPTQ-Int4": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
         },
-        "Qwen2VL-72B-Instruct-AWQ": {
+        "Qwen2-VL-72B-Instruct-AWQ": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
         },
@@ -1673,10 +1861,12 @@ register_model_group
         "TeleChat-7B-Chat": {
             DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
             DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
+            DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
         },
         "TeleChat-12B-Chat": {
             DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
             DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
+            DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
         },
         "TeleChat-12B-v2-Chat": {
             DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
@@ -1689,11 +1879,11 @@ register_model_group
 
 register_model_group(
     models={
-        "Vicuna1.5-7B-Chat": {
+        "Vicuna-v1.5-7B-Chat": {
             DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
             DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
         },
-        "Vicuna1.5-13B-Chat": {
+        "Vicuna-v1.5-13B-Chat": {
             DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
             DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
         },
@@ -1702,6 +1892,17 @@ register_model_group
 )
 
 
+register_model_group(
+    models={
+        "Video-LLaVA-7B-Chat": {
+            DownloadSource.DEFAULT: "LanguageBind/Video-LLaVA-7B-hf",
+        },
+    },
+    template="video_llava",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "XuanYuan-6B": {
@@ -1712,7 +1913,7 @@ register_model_group
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
         },
-        "XuanYuan-2-70B": {
+        "XuanYuan2-70B": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
         },
@@ -1724,31 +1925,31 @@ register_model_group
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
         },
-        "XuanYuan-2-70B-Chat": {
+        "XuanYuan2-70B-Chat": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
         },
-        "XuanYuan-6B-int8-Chat": {
+        "XuanYuan-6B-Chat-8bit": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
         },
-        "XuanYuan-6B-int4-Chat": {
+        "XuanYuan-6B-Chat-4bit": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
         },
-        "XuanYuan-70B-int8-Chat": {
+        "XuanYuan-70B-Chat-8bit": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
         },
-        "XuanYuan-70B-int4-Chat": {
+        "XuanYuan-70B-Chat-4bit": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
         },
-        "XuanYuan-2-70B-int8-Chat": {
+        "XuanYuan2-70B-Chat-8bit": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
         },
-        "XuanYuan-2-70B-int4-Chat": {
+        "XuanYuan2-70B-Chat-4bit": {
             DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
             DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
         },
@@ -1853,19 +2054,19 @@ register_model_group
             DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
             DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
         },
-        "Yi-6B-int8-Chat": {
+        "Yi-6B-Chat-8bits": {
             DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
         },
-        "Yi-6B-int4-Chat": {
+        "Yi-6B-Chat-4bits": {
             DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
         },
-        "Yi-34B-int8-Chat": {
+        "Yi-34B-Chat-8bits": {
             DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
         },
-        "Yi-34B-int4-Chat": {
+        "Yi-34B-Chat-4bits": {
             DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
         },
@@ -1884,6 +2085,7 @@ register_model_group
         "Yi-1.5-6B-Chat": {
             DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
             DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
+            DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
         },
         "Yi-1.5-9B-Chat": {
             DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
@@ -1916,10 +2118,10 @@ register_model_group
 
 register_model_group(
     models={
-        "YiVL-6B-Chat": {
+        "Yi-VL-6B-Chat": {
             DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
         },
-        "YiVL-34B-Chat": {
+        "Yi-VL-34B-Chat": {
             DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
         },
     },
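Note: all of the entries above funnel through the same registration helper. The following minimal sketch shows the assumed structure behind these hunks; the enum member names, `register_model_group` signature, and the `OPENMIND` source match the diff, while the lookup dictionaries and the enum string values are illustrative stand-ins rather than the project's exact internals.

```python
from collections import OrderedDict
from enum import Enum, unique
from typing import Dict, Optional


@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"      # Hugging Face Hub (illustrative value)
    MODELSCOPE = "ms"   # ModelScope Hub
    OPENMIND = "om"     # Modelers (openMind) Hub, newly supported in this commit


SUPPORTED_MODELS: Dict[str, Dict[DownloadSource, str]] = OrderedDict()
DEFAULT_TEMPLATE: Dict[str, str] = OrderedDict()
VISION_MODELS: set = set()


def register_model_group(
    models: Dict[str, Dict[DownloadSource, str]],
    template: Optional[str] = None,
    vision: bool = False,
) -> None:
    """Record hub paths, chat template, and modality for a group of models."""
    for name, path in models.items():
        SUPPORTED_MODELS[name] = path
        if template is not None:
            DEFAULT_TEMPLATE[name] = template
        if vision:
            VISION_MODELS.add(name)


# Usage mirroring one of the new entries in this diff:
register_model_group(
    models={
        "Pixtral-12B-Chat": {
            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
        }
    },
    template="pixtral",
    vision=True,
)
```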
@@ -72,4 +72,4 @@ def print_env() -> None:
     except Exception:
         pass
 
-    print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
+    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
@@ -20,6 +20,7 @@ import os
 import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from functools import lru_cache
 from typing import Optional
 
 from .constants import RUNNING_LOG
@@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
 
     def __init__(self, output_dir: str) -> None:
         super().__init__()
-        formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+        self._formatter = logging.Formatter(
+            fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         self.setLevel(logging.INFO)
-        self.setFormatter(formatter)
-
         os.makedirs(output_dir, exist_ok=True)
         self.running_log = os.path.join(output_dir, RUNNING_LOG)
         if os.path.exists(self.running_log):
@@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
         if record.name == "httpx":
             return
 
-        log_entry = self.format(record)
+        log_entry = self._formatter.format(record)
         self.thread_pool.submit(self._write_log, log_entry)
 
     def close(self) -> None:
@@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
         return super().close()
 
 
+class _Logger(logging.Logger):
+    r"""
+    A logger that supports info_rank0 and warning_once.
+    """
+
+    def info_rank0(self, *args, **kwargs) -> None:
+        self.info(*args, **kwargs)
+
+    def warning_rank0(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+    def warning_once(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+
 def _get_default_logging_level() -> "logging._Level":
     r"""
     Returns the default logging level.
@@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
     if env_level_str.upper() in logging._nameToLevel:
         return logging._nameToLevel[env_level_str.upper()]
     else:
-        raise ValueError("Unknown logging level: {}.".format(env_level_str))
+        raise ValueError(f"Unknown logging level: {env_level_str}.")
 
     return _default_log_level
 
@@ -84,7 +99,7 @@ def _get_library_name() -> str:
     return __name__.split(".")[0]
 
 
-def _get_library_root_logger() -> "logging.Logger":
+def _get_library_root_logger() -> "_Logger":
     return logging.getLogger(_get_library_name())
 
 
@@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
     global _default_handler
 
     with _thread_lock:
-        if _default_handler:
+        if _default_handler:  # already configured
             return
 
         formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-            datefmt="%m/%d/%Y %H:%M:%S",
+            fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         _default_handler = logging.StreamHandler(sys.stdout)
         _default_handler.setFormatter(formatter)
@@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
     library_root_logger.propagate = False
 
 
-def get_logger(name: Optional[str] = None) -> "logging.Logger":
+def get_logger(name: Optional[str] = None) -> "_Logger":
     r"""
     Returns a logger with the specified name. It it not supposed to be accessed externally.
     """
@@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
 
     _configure_library_root_logger()
     return logging.getLogger(name)
+
+
+def add_handler(handler: "logging.Handler") -> None:
+    r"""
+    Adds a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    r"""
+    Removes a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().removeHandler(handler)
+
+
+def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.info(*args, **kwargs)
+
+
+def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+@lru_cache(None)
+def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+logging.Logger.info_rank0 = info_rank0
+logging.Logger.warning_rank0 = warning_rank0
+logging.Logger.warning_once = warning_once
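Note: the hunk above attaches rank-aware helpers directly to `logging.Logger`, so any logger created through `get_logger` gains `info_rank0`, `warning_rank0`, and a deduplicated `warning_once`. The standalone sketch below reproduces that pattern outside the project so the behaviour is easy to see; the helper names match the diff, while the demo logger setup is illustrative.

```python
# Standalone sketch of the rank-0 / warn-once pattern used in the hunk above.
import logging
import os
from functools import lru_cache


def warning_rank0(self: logging.Logger, *args, **kwargs) -> None:
    # Only the main process (LOCAL_RANK=0, or unset) emits the warning.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


@lru_cache(None)
def warning_once(self: logging.Logger, *args, **kwargs) -> None:
    # lru_cache keys on (logger, message), so repeating the same call
    # becomes a no-op after the first emission.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")
logger.warning_once("printed once")
logger.warning_once("printed once")  # suppressed by the cache
logger.warning_rank0("printed only on the rank-0 process")
```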
@@ -32,7 +32,7 @@ from transformers.utils import (
 )
 from transformers.utils.versions import require_version
 
-from .logging import get_logger
+from . import logging
 
 
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
     from ..hparams import ModelArguments
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class AverageMeter:
@@ -76,12 +76,12 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
-        require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
-        require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
+        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+        require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
+        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
 
@@ -231,18 +231,35 @@ def torch_gc() -> None:
         torch.cuda.empty_cache()
 
 
-def try_download_model_from_ms(model_args: "ModelArguments") -> str:
-    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
+def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
+    if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
         return model_args.model_name_or_path
 
-    try:
-        from modelscope import snapshot_download
+    if use_modelscope():
+        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        from modelscope import snapshot_download  # type: ignore
 
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
-    except ImportError:
-        raise ImportError("Please install modelscope via `pip install modelscope -U`")
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=revision,
+            cache_dir=model_args.cache_dir,
+        )
+
+    if use_openmind():
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        from openmind.utils.hub import snapshot_download  # type: ignore
+
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=model_args.model_revision,
+            cache_dir=model_args.cache_dir,
+        )
 
 
 def use_modelscope() -> bool:
     return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+
+
+def use_openmind() -> bool:
+    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
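Note: `try_download_model_from_other_hub` keeps the old behaviour (return a local path or Hugging Face id untouched) unless one of the hub switches is set. The sketch below mimics that dispatch order without importing the real download clients; the environment-variable names come from the diff, while the argument class is a hypothetical stand-in for `ModelArguments` and the return strings are placeholders.

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyModelArguments:
    # Stand-in for llamafactory's ModelArguments; only the fields the
    # downloader touches are reproduced here.
    model_name_or_path: str = "Qwen/Qwen2-0.5B-Instruct"
    model_revision: str = "main"
    cache_dir: Optional[str] = None


def use_modelscope() -> bool:
    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]


def use_openmind() -> bool:
    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]


def resolve_model_path(args: DummyModelArguments) -> str:
    """Mimics the dispatch order of try_download_model_from_other_hub."""
    if (not use_modelscope() and not use_openmind()) or os.path.exists(args.model_name_or_path):
        return args.model_name_or_path  # plain HF id or an existing local path
    if use_modelscope():
        return f"<download {args.model_name_or_path} from ModelScope>"
    return f"<download {args.model_name_or_path} from the openMind hub>"


os.environ["USE_OPENMIND_HUB"] = "1"
print(resolve_model_path(DummyModelArguments()))
```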
@@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
     return _get_package_version("transformers") >= version.parse("4.43.0")
 
 
+@lru_cache
+def is_transformers_version_equal_to_4_46():
+    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
+
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
 
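Note: the new helper is a memoized version gate. A hedged, self-contained sketch of the same pattern is below; the package name and version bounds come from the hunk, while `_get_package_version` and its fallback value are a plausible reconstruction rather than the project's exact helper.

```python
# Sketch of a memoized version-range check, assuming the `packaging` library.
import importlib.metadata
from functools import lru_cache

from packaging import version


@lru_cache
def _get_package_version(name: str) -> "version.Version":
    try:
        return version.parse(importlib.metadata.version(name))
    except importlib.metadata.PackageNotFoundError:
        return version.parse("0.0.0")  # treat missing packages as very old


@lru_cache
def is_transformers_version_equal_to_4_46() -> bool:
    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")


print(is_transformers_version_equal_to_4_46())
```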
@@ -19,7 +19,7 @@ from typing import Any, Dict, List
 
 from transformers.trainer import TRAINER_STATE_NAME
 
-from .logging import get_logger
+from . import logging
 from .packages import is_matplotlib_available
 
 
@@ -28,7 +28,7 @@ if is_matplotlib_available():
     import matplotlib.pyplot as plt
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def smooth(scalars: List[float]) -> List[float]:
@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
     Plots loss curves and saves the image.
     """
     plt.switch_backend("agg")
-    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
         data = json.load(f)
 
     for key in keys:
@@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
             metrics.append(data["log_history"][i][key])
 
         if len(metrics) == 0:
-            logger.warning(f"No metric {key} to plot.")
+            logger.warning_rank0(f"No metric {key} to plot.")
             continue
 
         plt.figure()
         plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
-        plt.title("training {} of {}".format(key, save_dictionary))
+        plt.title(f"training {key} of {save_dictionary}")
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
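Note: `plot_loss` reads the saved trainer state from an output directory and writes one smoothed curve per key. A hedged usage sketch follows; the output directory is an example path, not a required layout.

```python
# Hypothetical usage after a finished run; plot_loss comes from
# llamafactory.extras.ploting and reads <output_dir>/trainer_state.json,
# writing one image per key (e.g. training_loss.png) into the same directory.
from llamafactory.extras.ploting import plot_loss

plot_loss("saves/llama3-8b/lora/sft", keys=["loss", "eval_loss"])
```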
@@ -41,6 +41,10 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
+    image_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
+    )
     cutoff_len: int = field(
         default=1024,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
@@ -111,7 +115,13 @@ class DataArguments:
     )
     tokenized_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save or load the tokenized datasets."},
+        metadata={
+            "help": (
+                "Path to save or load the tokenized datasets. "
+                "If tokenized_path not exists, it will save the tokenized datasets. "
+                "If tokenized_path exists, it will load the tokenized datasets."
+            )
+        },
     )
 
     def __post_init__(self):
@@ -123,6 +133,9 @@ class DataArguments:
         self.dataset = split_arg(self.dataset)
         self.eval_dataset = split_arg(self.eval_dataset)
 
+        if self.image_dir is None:
+            self.image_dir = self.dataset_dir
+
         if self.dataset is None and self.val_size > 1e-6:
             raise ValueError("Cannot specify `val_size` if `dataset` is None.")
 
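Note: the new `image_dir` field silently falls back to `dataset_dir` in `__post_init__`, so existing configs keep working. A minimal, self-contained illustration of that fallback is below; the field names mirror `DataArguments`, everything else is invented for the example.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TinyDataArguments:
    dataset_dir: str = field(default="data")
    image_dir: Optional[str] = field(default=None)

    def __post_init__(self):
        if self.image_dir is None:
            self.image_dir = self.dataset_dir  # media defaults to the dataset folder


print(TinyDataArguments().image_dir)                          # -> "data"
print(TinyDataArguments(image_dir="media/images").image_dir)  # -> "media/images"
```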
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union
 
 import torch
@@ -267,6 +267,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
     )
+    om_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Modelers Hub."},
+    )
     print_param_status: bool = field(
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
@@ -308,20 +312,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")
 
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
     @classmethod
-    def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
-        arg_dict = old_arg.to_dict()
-        arg_dict.update(**kwargs)
-        for attr in fields(cls):
-            if not attr.init:
-                arg_dict.pop(attr.name)
-
-        new_arg = cls(**arg_dict)
-        new_arg.compute_dtype = old_arg.compute_dtype
-        new_arg.device_map = old_arg.device_map
-        new_arg.model_max_length = old_arg.model_max_length
-        new_arg.block_diag_attn = old_arg.block_diag_attn
-        return new_arg
+    def copyfrom(cls, source: "Self", **kwargs) -> "Self":
+        init_args, lazy_args = {}, {}
+        for attr in fields(source):
+            if attr.init:
+                init_args[attr.name] = getattr(source, attr.name)
+            else:
+                lazy_args[attr.name] = getattr(source, attr.name)
+
+        init_args.update(kwargs)
+        result = cls(**init_args)
+        for name, value in lazy_args.items():
+            setattr(result, name, value)
+
+        return result
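Note: the rewritten `copyfrom` stops round-tripping through `asdict` (which deep-converts nested values and cannot carry non-init fields) and instead splits the dataclass fields into constructor arguments and lazily-set attributes, copying the latter verbatim after construction. The toy example below demonstrates the same split; the `Config` class and its fields are invented for illustration.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class Config:
    model_name_or_path: str = "demo"
    cache_dir: Optional[str] = None
    compute_dtype: Optional[str] = field(default=None, init=False)  # derived after init

    @classmethod
    def copyfrom(cls, source: "Config", **kwargs) -> "Config":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            (init_args if attr.init else lazy_args)[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)          # overrides only touch real constructor args
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)  # carry over post-init state untouched
        return result


src = Config(cache_dir="/tmp/hf")
src.compute_dtype = "bfloat16"
copy = Config.copyfrom(src, model_name_or_path="other")
print(copy.model_name_or_path, copy.cache_dir, copy.compute_dtype)  # other /tmp/hf bfloat16
```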
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import os
 import sys
 from typing import Any, Dict, Optional, Tuple
@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from transformers.utils.versions import require_version
 
+from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 check_dependencies()
@@ -57,7 +56,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     if args is not None:
         return parser.parse_dict(args)
 
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
         return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
 
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
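Note: the change above means a single config argument may now end in `.yml` as well as `.yaml`. The small sketch below shows the suffix dispatch in isolation; the parser method names exist on `transformers.HfArgumentParser`, while the command lines are illustrative.

```python
def pick_loader(argv: list) -> str:
    """Mirror the file-extension dispatch used by _parse_args."""
    if len(argv) == 2 and (argv[1].endswith(".yaml") or argv[1].endswith(".yml")):
        return "parse_yaml_file"
    if len(argv) == 2 and argv[1].endswith(".json"):
        return "parse_json_file"
    return "parse_args_into_dataclasses"


print(pick_loader(["llamafactory-cli", "train_config.yaml"]))  # parse_yaml_file
print(pick_loader(["llamafactory-cli", "train_config.yml"]))   # parse_yaml_file (new)
print(pick_loader(["llamafactory-cli", "--stage", "sft"]))     # parse_args_into_dataclasses
```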
@@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
 
     if unknown_args:
         print(parser.format_help())
-        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
-        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
 
     return (*parsed_args,)
 
 
-def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-    transformers.utils.logging.set_verbosity(log_level)
+def _set_transformers_logging() -> None:
+    transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
 
@@ -104,7 +103,7 @@ def _verify_model_args(
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
     if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
+        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False
 
 
@@ -123,7 +122,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
 
     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0")
+        require_version("vllm>=0.4.3,<=0.6.3", "To fix: pip install vllm>=0.4.3,<=0.6.3")
 
     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
     if data_args.neat_packing and not data_args.packing:
-        logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
+        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
         data_args.packing = True
 
     _verify_model_args(model_args, data_args, finetuning_args)
@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
+        logger.warning_rank0(
+            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
+        )
 
     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
-        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
 
     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
-        logger.warning("We recommend enable mixed precision training.")
+        logger.warning_rank0("We recommend enable mixed precision training.")
 
     if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+        logger.warning_rank0(
+            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
+        )
 
     if (not training_args.do_train) and model_args.quantization_bit is not None:
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
 
     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
-        logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
 
     # Post-process training arguments
     if (
@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False
 
     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
         if training_args.resume_from_checkpoint is not None:
-            logger.warning("Cannot resume from checkpoint in current stage.")
+            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
             training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True
@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
-            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
+            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
+            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
 
     if (
         finetuning_args.stage in ["rm", "ppo"]
         and finetuning_args.finetuning_type == "lora"
         and training_args.resume_from_checkpoint is not None
     ):
-        logger.warning(
+        logger.warning_rank0(
             "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                 training_args.resume_from_checkpoint
             )
@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _setup_full_tuning(
@@ -45,7 +45,7 @@ def _setup_full_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Full")
+    logger.info_rank0("Fine-tuning method: Full")
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Freeze")
+    logger.info_rank0("Fine-tuning method: Freeze")
     if hasattr(model.config, "text_config"):  # composite models
         config = getattr(model.config, "text_config")
     else:
@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
         else:
             param.requires_grad_(False)

-    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+    logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))


 def _setup_lora_tuning(
@@ -145,7 +145,7 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None

@@ -182,7 +182,7 @@ def _setup_lora_tuning(
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
-            logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
+            logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")

         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:
@@ -190,7 +190,7 @@ def _setup_lora_tuning(
             else:
                 model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+        logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -219,7 +219,7 @@ def _setup_lora_tuning(
                     module_names.add(name.split(".")[-1])

             finetuning_args.additional_target = module_names
-            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+            logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

         peft_kwargs = {
             "r": finetuning_args.lora_rank,
@@ -236,11 +236,11 @@ def _setup_lora_tuning(
         else:
             if finetuning_args.pissa_init:
                 if finetuning_args.pissa_iter == -1:
-                    logger.info("Using PiSSA initialization.")
+                    logger.info_rank0("Using PiSSA initialization.")
                     peft_kwargs["init_lora_weights"] = "pissa"
                 else:
-                    logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
-                    peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+                    logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
+                    peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
@@ -284,11 +284,11 @@ def init_adapter(
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+        logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
     elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
-        logger.info("Upcasting trainable params to float32.")
+        logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

     if finetuning_args.finetuning_type == "full":
@@ -300,6 +300,6 @@ def init_adapter(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
     else:
-        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
+        raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")

     return model
@@ -18,15 +18,15 @@ import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

-from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
+from ..extras import logging
+from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
 from .adapter import init_adapter
+from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
 from .model_utils.valuehead import load_valuehead_params
-from .model_utils.visual import get_image_seqlen
-from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
+from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model


 if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class TokenizerModule(TypedDict):
@@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     Note: including inplace operation of model_args.
     """
     skip_check_imports()
-    model_args.model_name_or_path = try_download_model_from_ms(model_args)
+    model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
     return {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
@@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:

 def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     r"""
-    Loads pretrained tokenizer.
+    Loads pretrained tokenizer and optionally loads processor.

     Note: including inplace operation of model_args.
     """
@@ -82,33 +82,30 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             padding_side="right",
             **init_kwargs,
         )
+    except Exception as e:
+        raise OSError("Failed to load tokenizer.") from e

     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
         if num_added_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")

     patch_tokenizer(tokenizer)

     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        setattr(processor, "tokenizer", tokenizer)
-        setattr(processor, "image_seqlen", get_image_seqlen(config))
-        setattr(processor, "image_resolution", model_args.image_resolution)
-        setattr(processor, "video_resolution", model_args.video_resolution)
-        setattr(processor, "video_fps", model_args.video_fps)
-        setattr(processor, "video_maxlen", model_args.video_maxlen)
-    except Exception:
+        patch_processor(processor, config, tokenizer, model_args)
+    except Exception as e:
+        logger.debug(f"Processor was not found: {e}.")
         processor = None

     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
-    if "Processor" not in processor.__class__.__name__:
+    if processor is not None and "Processor" not in processor.__class__.__name__:
         processor = None

     return {"tokenizer": tokenizer, "processor": processor}
@@ -135,6 +132,7 @@ def load_model(
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))

     model = None
     lazy_load = False
@@ -157,7 +155,7 @@ def load_model(
                 load_class = AutoModelForCausalLM

             if model_args.train_from_scratch:
-                model = load_class.from_config(config)
+                model = load_class.from_config(config, trust_remote_code=True)
             else:
                 model = load_class.from_pretrained(**init_kwargs)

@@ -182,7 +180,7 @@ def load_model(
         vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
-            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
+            logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

     if not is_trainable:
         model.requires_grad_(False)
@@ -200,9 +198,9 @@ def load_model(
             trainable_params, all_param, 100 * trainable_params / all_param
         )
     else:
-        param_stats = "all params: {:,}".format(all_param)
+        param_stats = f"all params: {all_param:,}"

-    logger.info(param_stats)
+    logger.info_rank0(param_stats)

     if model_args.print_param_status:
         for name, param in model.named_parameters():
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 from transformers.utils.versions import require_version

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def configure_attn_implementation(
@@ -37,13 +37,16 @@ def configure_attn_implementation(
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                 require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
-                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-                model_args.flash_attn = "fa2"
+                if model_args.flash_attn != "fa2":
+                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                    model_args.flash_attn = "fa2"
             else:
-                logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
+                logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
                 model_args.flash_attn = "disabled"
         elif model_args.flash_attn == "sdpa":
-            logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
+            logger.warning_rank0(
+                "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
+            )

     if model_args.flash_attn == "auto":
         return
@@ -53,18 +56,18 @@ def configure_attn_implementation(

     elif model_args.flash_attn == "sdpa":
         if not is_torch_sdpa_available():
-            logger.warning("torch>=2.1.1 is required for SDPA attention.")
+            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return

         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == "fa2":
         if not is_flash_attn_2_available():
-            logger.warning("FlashAttention-2 is not installed.")
+            logger.warning_rank0("FlashAttention-2 is not installed.")
             return

         requested_attn_implementation = "flash_attention_2"
     else:
-        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+        raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")

     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
         setattr(config, "attn_implementation", requested_attn_implementation)
@@ -79,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
     attn_implementation = getattr(config, "_attn_implementation", None)

     if attn_implementation == "flash_attention_2":
-        logger.info("Using FlashAttention-2 for faster training and inference.")
+        logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
     elif attn_implementation == "sdpa":
-        logger.info("Using torch SDPA for faster training and inference.")
+        logger.info_rank0("Using torch SDPA for faster training and inference.")
     else:
-        logger.info("Using vanilla attention implementation.")
+        logger.info_rank0("Using vanilla attention implementation.")
@@ -19,14 +19,14 @@
 # limitations under the License.

 import inspect
-from functools import partial, wraps
+from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

 import torch

+from ...extras import logging
 from ...extras.constants import LAYERNORM_NAMES
-from ...extras.logging import get_logger


 if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def get_unsloth_gradient_checkpointing_func() -> Callable:
@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     Only applies gradient checkpointing to trainable layers.
     """

-    @wraps(gradient_checkpointing_func)
+    @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__

@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable

         return gradient_checkpointing_func(func, *args, **kwargs)

-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
-        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
-
     return custom_gradient_checkpointing_func


@@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
     from torch.utils.checkpoint import checkpoint

     if not self.supports_gradient_checkpointing:
-        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {"use_reentrant": True}
@@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
-        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

@@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
     (3) add the upcasting of the lm_head in fp32
     """
     if model_args.upcast_layernorm:
-        logger.info("Upcasting layernorm weights in float32.")
+        logger.info_rank0("Upcasting layernorm weights in float32.")
         for name, param in model.named_parameters():
             if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                 param.data = param.data.to(torch.float32)

     if not model_args.disable_gradient_checkpointing:
         if not getattr(model, "supports_gradient_checkpointing", False):
-            logger.warning("Current model does not support gradient checkpointing.")
+            logger.warning_rank0("Current model does not support gradient checkpointing.")
         else:
             # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
             # According to: https://github.com/huggingface/transformers/issues/28339
@@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
             model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
-            logger.info("Gradient checkpointing enabled.")
+            logger.info_rank0("Gradient checkpointing enabled.")

     if model_args.upcast_lmhead_output:
         output_layer = model.get_output_embeddings()
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            logger.info("Upcasting lm_head outputs in float32.")
+            logger.info_rank0("Upcasting lm_head outputs in float32.")
             output_layer.register_forward_hook(_fp32_forward_post_hook)
@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
         _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
         _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

-    logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
+    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -23,10 +24,15 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


-def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def apply_liger_kernel(
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    is_trainable: bool,
+    require_logits: bool,
+) -> None:
     if not is_trainable or not model_args.enable_liger_kernel:
         return

@@ -48,8 +54,14 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     else:
-        logger.warning("Current model does not support liger kernel.")
+        logger.warning_rank0("Current model does not support liger kernel.")
         return

-    apply_liger_kernel()
-    logger.info("Liger kernel has been applied to the model.")
+    if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
+        logger.info_rank0("Current training stage does not support chunked cross entropy.")
+        kwargs = {"fused_linear_cross_entropy": False}
+    else:
+        kwargs = {}
+
+    apply_liger_kernel(**kwargs)
+    logger.info_rank0("Liger kernel has been applied to the model.")
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple

 import torch
 import torch.nn as nn
+import transformers
 from transformers.models.llama.modeling_llama import (
     Cache,
     LlamaAttention,
@@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43


@@ -44,7 +44,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-transformers_logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)


 # Modified from:
@@ -86,7 +86,7 @@ def llama_attention_forward(

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return

-    logger = get_logger(__name__)
+    logger = logging.get_logger(__name__)

     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
-        logger.info("Using shift short attention with group_size_ratio=1/4.")
+        logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
     else:
-        logger.warning("Current model does not support shift short attention.")
+        logger.warning_rank0("Current model does not support shift short attention.")
@@ -14,14 +14,14 @@

 from typing import TYPE_CHECKING, List

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -34,7 +34,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "paligemma"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
     elif model_type == "qwen2_vl":
         forbidden_modules.add("merger")
@@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])

-    logger.info("Found linear modules: {}".format(",".join(module_names)))
+    logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
     return list(module_names)


@@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n

     if num_layers % num_layer_trainable != 0:
         raise ValueError(
-            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+            f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
         )

     stride = num_layers // num_layer_trainable
     trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
-    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+    trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
     module_names = []
     for name, _ in model.named_modules():
         if any(target_module in name for target_module in target_modules) and any(
@@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)

-    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
     return module_names

@@ -43,8 +43,8 @@ import torch
 import torch.nn.functional as F
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43


@@ -54,7 +54,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor


 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils

@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
         _patch_for_block_diag_attn(model_type)
-        logger.info("Using block diagonal attention for sequence packing without cross-attention.")
+        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.logging import get_logger
 from ...extras.misc import get_current_device


@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique
@@ -109,7 +109,7 @@ def configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if model_args.quantization_bit is not None:
-            logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
+            logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")

         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
@@ -130,7 +130,7 @@ def configure_quantization(
             quantization_config["bits"] = 2

         quant_bits = quantization_config.get("bits", "?")
-        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
+        logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

     elif model_args.export_quantization_bit is not None:  # auto-gptq
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@@ -149,7 +149,7 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")

     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@@ -179,7 +179,7 @@ def configure_quantization(
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

-            logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
         elif model_args.quantization_method == QuantizationMethod.HQQ.value:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -191,7 +191,7 @@ def configure_quantization(
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
-            logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
         elif model_args.quantization_method == QuantizationMethod.EETQ.value:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
@@ -201,4 +201,4 @@ def configure_quantization(

             require_version("eetq", "To fix: pip install eetq")
             init_kwargs["quantization_config"] = EetqConfig()
-            logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
@@ -19,7 +19,7 @@
 import math
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         return

     if not hasattr(config, "rope_scaling"):
-        logger.warning("Current model does not support RoPE scaling.")
+        logger.warning_rank0("Current model does not support RoPE scaling.")
         return

     if model_args.model_max_length is not None:
         if is_trainable and model_args.rope_scaling == "dynamic":
-            logger.warning(
+            logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )

         current_max_length = getattr(config, "max_position_embeddings", None)
         if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info(
-                "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
-            )
+            logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
             scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
-            logger.warning("Input length is smaller than max length. Consider increase input length.")
+            logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
             scaling_factor = 1.0
     else:
         scaling_factor = 2.0

     setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-    logger.info(
-        "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
+    logger.info_rank0(
+        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
     )
@@ -14,7 +14,7 @@

 from typing import TYPE_CHECKING, Any, Dict, Optional

-from ...extras.logging import get_logger
+from ...extras import logging
 from ...extras.misc import get_current_device


@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
-        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+        logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
         model = None
         model_args.use_unsloth = False

@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
|
|||||||
import torch
|
import torch
|
||||||
from transformers.utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
|
from ...extras import logging
|
||||||
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ...extras.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
|||||||
from ...hparams import ModelArguments
|
from ...hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||||
@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
err_text = str(err)
|
err_text = str(err)
|
||||||
|
|
||||||
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
|
logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
|
||||||
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
|
logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,11 +18,11 @@
|
|||||||
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
|
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
import transformers.models
|
import transformers.models
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
from ...extras.logging import get_logger
|
from ...extras import logging
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -31,8 +31,8 @@ if TYPE_CHECKING:
|
|||||||
from ...hparams import FinetuningArguments, ModelArguments
|
from ...hparams import FinetuningArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
transformers_logger = logging.get_logger(__name__)
|
transformers_logger = transformers.utils.logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
|
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
|
||||||
@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
|
|||||||
|
|
||||||
if getattr(model, "quantization_method", None):
|
if getattr(model, "quantization_method", None):
|
||||||
model_type = getattr(model.config, "model_type", None)
|
model_type = getattr(model.config, "model_type", None)
|
||||||
if model_type in ["llava", "paligemma"]:
|
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
|
||||||
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
|
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
|
||||||
elif model_type == "qwen2_vl":
|
elif model_type == "qwen2_vl":
|
||||||
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
|
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
|
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
|
||||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||||
|
|
||||||
|
|
||||||
@ -108,11 +108,18 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
|
|||||||
Patches VLMs before loading them.
|
Patches VLMs before loading them.
|
||||||
"""
|
"""
|
||||||
model_type = getattr(config, "model_type", None)
|
model_type = getattr(config, "model_type", None)
|
||||||
if model_type == "llava": # required for ds zero3 and valuehead models
|
if model_type in [
|
||||||
|
"llava",
|
||||||
|
"llava_next",
|
||||||
|
"llava_next_video",
|
||||||
|
"paligemma",
|
||||||
|
"pixtral",
|
||||||
|
"video_llava",
|
||||||
|
]: # required for ds zero3 and valuehead models
|
||||||
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
|
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
|
||||||
|
|
||||||
if getattr(config, "is_yi_vl_derived_model", None):
|
if getattr(config, "is_yi_vl_derived_model", None):
|
||||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||||
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
|
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
|
||||||
|
|
||||||
|
|
||||||
@ -122,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
|
|||||||
"""
|
"""
|
||||||
model_type = getattr(config, "model_type", None)
|
model_type = getattr(config, "model_type", None)
|
||||||
forbidden_modules = set()
|
forbidden_modules = set()
|
||||||
if model_type in ["llava", "paligemma"]:
|
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
|
||||||
if finetuning_args.freeze_vision_tower:
|
if finetuning_args.freeze_vision_tower:
|
||||||
forbidden_modules.add("vision_tower")
|
forbidden_modules.add("vision_tower")
|
||||||
|
|
||||||
@ -150,12 +157,28 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
|
|||||||
image_seqlen += 1
|
image_seqlen += 1
|
||||||
elif model_type == "paligemma":
|
elif model_type == "paligemma":
|
||||||
image_seqlen = config.vision_config.num_image_tokens
|
image_seqlen = config.vision_config.num_image_tokens
|
||||||
elif model_type == "qwen2_vl": # variable length
|
else:
|
||||||
image_seqlen = -1
|
image_seqlen = -1
|
||||||
|
|
||||||
return image_seqlen
|
return image_seqlen
|
||||||
|
|
||||||
|
|
||||||
|
def get_patch_size(config: "PretrainedConfig") -> int:
|
||||||
|
r"""
|
||||||
|
Computes the patch size of the vit.
|
||||||
|
"""
|
||||||
|
patch_size = getattr(config.vision_config, "patch_size", -1)
|
||||||
|
return patch_size
|
||||||
|
|
||||||
|
|
||||||
|
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
|
||||||
|
r"""
|
||||||
|
Get the vision_feature_select_strategy.
|
||||||
|
"""
|
||||||
|
vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
|
||||||
|
return vision_feature_select_strategy
|
||||||
|
|
||||||
|
|
||||||
def patch_target_modules(
|
def patch_target_modules(
|
||||||
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
|
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
@ -164,7 +187,7 @@ def patch_target_modules(
|
|||||||
"""
|
"""
|
||||||
model_type = getattr(config, "model_type", None)
|
model_type = getattr(config, "model_type", None)
|
||||||
if finetuning_args.freeze_vision_tower:
|
if finetuning_args.freeze_vision_tower:
|
||||||
if model_type in ["llava", "paligemma"]:
|
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
|
||||||
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
|
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
|
||||||
elif model_type == "qwen2_vl":
|
elif model_type == "qwen2_vl":
|
||||||
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
|
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
|
||||||
@ -173,5 +196,7 @@ def patch_target_modules(
|
|||||||
else:
|
else:
|
||||||
if model_type == "qwen2_vl":
|
if model_type == "qwen2_vl":
|
||||||
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
|
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
|
||||||
|
elif model_type == "pixtral":
|
||||||
|
return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
|
||||||
else:
|
else:
|
||||||
return target_modules
|
return target_modules
|
||||||
|
@ -22,29 +22,34 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
|
|||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.modeling_utils import is_fsdp_enabled
|
from transformers.modeling_utils import is_fsdp_enabled
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras import logging
|
||||||
from ..extras.misc import infer_optim_dtype
|
from ..extras.misc import infer_optim_dtype
|
||||||
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
|
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
|
||||||
from .model_utils.checkpointing import prepare_model_for_training
|
from .model_utils.checkpointing import prepare_model_for_training
|
||||||
from .model_utils.embedding import resize_embedding_layer
|
from .model_utils.embedding import resize_embedding_layer
|
||||||
from .model_utils.liger_kernel import configure_liger_kernel
|
|
||||||
from .model_utils.longlora import configure_longlora
|
from .model_utils.longlora import configure_longlora
|
||||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||||
from .model_utils.packing import configure_packing
|
from .model_utils.packing import configure_packing
|
||||||
from .model_utils.quantization import configure_quantization
|
from .model_utils.quantization import configure_quantization
|
||||||
from .model_utils.rope import configure_rope
|
from .model_utils.rope import configure_rope
|
||||||
from .model_utils.valuehead import prepare_valuehead_model
|
from .model_utils.valuehead import prepare_valuehead_model
|
||||||
from .model_utils.visual import autocast_projector_dtype, configure_visual_model
|
from .model_utils.visual import (
|
||||||
|
autocast_projector_dtype,
|
||||||
|
configure_visual_model,
|
||||||
|
get_image_seqlen,
|
||||||
|
get_patch_size,
|
||||||
|
get_vision_feature_select_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
from ..hparams import ModelArguments
|
from ..hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
||||||
@ -52,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
|||||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_processor(
|
||||||
|
processor: "ProcessorMixin",
|
||||||
|
config: "PretrainedConfig",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
) -> None:
|
||||||
|
setattr(processor, "tokenizer", tokenizer)
|
||||||
|
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||||
|
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||||
|
setattr(processor, "patch_size", get_patch_size(config))
|
||||||
|
setattr(processor, "video_resolution", model_args.video_resolution)
|
||||||
|
setattr(processor, "video_fps", model_args.video_fps)
|
||||||
|
setattr(processor, "video_maxlen", model_args.video_maxlen)
|
||||||
|
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
|
||||||
|
|
||||||
|
|
||||||
def patch_config(
|
def patch_config(
|
||||||
config: "PretrainedConfig",
|
config: "PretrainedConfig",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
@ -71,7 +92,6 @@ def patch_config(
|
|||||||
|
|
||||||
configure_attn_implementation(config, model_args, is_trainable)
|
configure_attn_implementation(config, model_args, is_trainable)
|
||||||
configure_rope(config, model_args, is_trainable)
|
configure_rope(config, model_args, is_trainable)
|
||||||
configure_liger_kernel(config, model_args, is_trainable)
|
|
||||||
configure_longlora(config, model_args, is_trainable)
|
configure_longlora(config, model_args, is_trainable)
|
||||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||||
configure_moe(config, model_args, is_trainable)
|
configure_moe(config, model_args, is_trainable)
|
||||||
@ -80,7 +100,7 @@ def patch_config(
|
|||||||
|
|
||||||
if model_args.use_cache and not is_trainable:
|
if model_args.use_cache and not is_trainable:
|
||||||
setattr(config, "use_cache", True)
|
setattr(config, "use_cache", True)
|
||||||
logger.info("Using KV cache for faster generation.")
|
logger.info_rank0("Using KV cache for faster generation.")
|
||||||
|
|
||||||
if getattr(config, "model_type", None) == "qwen":
|
if getattr(config, "model_type", None) == "qwen":
|
||||||
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
|
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
|
||||||
@ -90,6 +110,9 @@ def patch_config(
|
|||||||
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
|
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
|
||||||
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
|
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
|
||||||
|
|
||||||
|
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
||||||
|
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||||
|
|
||||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||||
|
|
||||||
@ -142,7 +165,7 @@ def patch_model(
|
|||||||
try:
|
try:
|
||||||
model.add_model_tags(["llama-factory"])
|
model.add_model_tags(["llama-factory"])
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Cannot properly tag the model.")
|
logger.warning_rank0("Cannot properly tag the model.")
|
||||||
|
|
||||||
|
|
||||||
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@ -34,8 +33,8 @@ from transformers.utils import (
|
|||||||
)
|
)
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.logging import LoggerHandler, get_logger
|
|
||||||
from ..extras.misc import get_peak_memory
|
from ..extras.misc import get_peak_memory
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +47,7 @@ if TYPE_CHECKING:
|
|||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def fix_valuehead_checkpoint(
|
def fix_valuehead_checkpoint(
|
||||||
@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
|
|||||||
else:
|
else:
|
||||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||||
|
|
||||||
logger.info("Value head model saved at: {}".format(output_dir))
|
logger.info_rank0(f"Value head model saved at: {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
class FixValueHeadModelCallback(TrainerCallback):
|
class FixValueHeadModelCallback(TrainerCallback):
|
||||||
@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback):
|
|||||||
Event called after a checkpoint save.
|
Event called after a checkpoint save.
|
||||||
"""
|
"""
|
||||||
if args.should_save:
|
if args.should_save:
|
||||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||||
fix_valuehead_checkpoint(
|
fix_valuehead_checkpoint(
|
||||||
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
|
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
|
||||||
)
|
)
|
||||||
@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback):
|
|||||||
@override
|
@override
|
||||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
if args.should_save:
|
if args.should_save:
|
||||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
self.processor.save_pretrained(output_dir)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
if args.should_save:
|
if args.should_save:
|
||||||
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
|
self.processor.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
class PissaConvertCallback(TrainerCallback):
|
class PissaConvertCallback(TrainerCallback):
|
||||||
@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
|
|||||||
if args.should_save:
|
if args.should_save:
|
||||||
model = kwargs.pop("model")
|
model = kwargs.pop("model")
|
||||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||||
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
|
logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
|
||||||
if isinstance(model, PeftModel):
|
if isinstance(model, PeftModel):
|
||||||
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
|
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
|
||||||
setattr(model.peft_config["default"], "init_lora_weights", True)
|
setattr(model.peft_config["default"], "init_lora_weights", True)
|
||||||
@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
|
|||||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||||
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
|
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
|
||||||
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
|
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
|
||||||
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
|
logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
|
||||||
# 1. save a pissa backup with init_lora_weights: True
|
# 1. save a pissa backup with init_lora_weights: True
|
||||||
# 2. save a converted lora with init_lora_weights: pissa
|
# 2. save a converted lora with init_lora_weights: pissa
|
||||||
# 3. load the pissa backup with init_lora_weights: True
|
# 3. load the pissa backup with init_lora_weights: True
|
||||||
@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
|
|||||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||||
if self.webui_mode:
|
if self.webui_mode:
|
||||||
signal.signal(signal.SIGABRT, self._set_abort)
|
signal.signal(signal.SIGABRT, self._set_abort)
|
||||||
self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
||||||
logging.root.addHandler(self.logger_handler)
|
logging.add_handler(self.logger_handler)
|
||||||
transformers.logging.add_handler(self.logger_handler)
|
transformers.logging.add_handler(self.logger_handler)
|
||||||
|
|
||||||
def _set_abort(self, signum, frame) -> None:
|
def _set_abort(self, signum, frame) -> None:
|
||||||
@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
|
|||||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||||
and args.overwrite_output_dir
|
and args.overwrite_output_dir
|
||||||
):
|
):
|
||||||
logger.warning("Previous trainer log in this folder will be deleted.")
|
logger.warning_once("Previous trainer log in this folder will be deleted.")
|
||||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -288,13 +287,13 @@ class LogCallback(TrainerCallback):
|
|||||||
logs = dict(
|
logs = dict(
|
||||||
current_steps=self.cur_steps,
|
current_steps=self.cur_steps,
|
||||||
total_steps=self.max_steps,
|
total_steps=self.max_steps,
|
||||||
loss=state.log_history[-1].get("loss", None),
|
loss=state.log_history[-1].get("loss"),
|
||||||
eval_loss=state.log_history[-1].get("eval_loss", None),
|
eval_loss=state.log_history[-1].get("eval_loss"),
|
||||||
predict_loss=state.log_history[-1].get("predict_loss", None),
|
predict_loss=state.log_history[-1].get("predict_loss"),
|
||||||
reward=state.log_history[-1].get("reward", None),
|
reward=state.log_history[-1].get("reward"),
|
||||||
accuracy=state.log_history[-1].get("rewards/accuracies", None),
|
accuracy=state.log_history[-1].get("rewards/accuracies"),
|
||||||
learning_rate=state.log_history[-1].get("learning_rate", None),
|
lr=state.log_history[-1].get("learning_rate"),
|
||||||
epoch=state.log_history[-1].get("epoch", None),
|
epoch=state.log_history[-1].get("epoch"),
|
||||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||||
elapsed_time=self.elapsed_time,
|
elapsed_time=self.elapsed_time,
|
||||||
remaining_time=self.remaining_time,
|
remaining_time=self.remaining_time,
|
||||||
@ -305,16 +304,17 @@ class LogCallback(TrainerCallback):
|
|||||||
|
|
||||||
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
|
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
|
||||||
vram_allocated, vram_reserved = get_peak_memory()
|
vram_allocated, vram_reserved = get_peak_memory()
|
||||||
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
|
logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
|
||||||
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
|
logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)
|
||||||
|
|
||||||
logs = {k: v for k, v in logs.items() if v is not None}
|
logs = {k: v for k, v in logs.items() if v is not None}
|
||||||
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
|
if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
|
||||||
logger.info(
|
log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
|
||||||
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
|
for extra_key in ("reward", "accuracy", "throughput"):
|
||||||
logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
|
if logs.get(extra_key):
|
||||||
)
|
log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
|
||||||
)
|
|
||||||
|
logger.info_rank0("{" + log_str + "}")
|
||||||
|
|
||||||
if self.thread_pool is not None:
|
if self.thread_pool is not None:
|
||||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||||
|
@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
from ...extras.packages import is_transformers_version_equal_to_4_46
|
||||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
|
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
|
||||||
|
|
||||||
@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
self.callback_handler.add_callback(PissaConvertCallback)
|
self.callback_handler.add_callback(PissaConvertCallback)
|
||||||
|
|
||||||
if finetuning_args.use_badam:
|
if finetuning_args.use_badam:
|
||||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||||
|
|
||||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||||
return super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||||
|
r"""
|
||||||
|
Replaces the method of KTO Trainer with the one of the standard Trainer.
|
||||||
|
"""
|
||||||
|
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
||||||
|
|
||||||
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
||||||
r"""
|
r"""
|
||||||
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
|
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
|
||||||
@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
elif self.loss_type == "simpo":
|
elif self.loss_type == "simpo":
|
||||||
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
|
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
|
raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
|
||||||
|
|
||||||
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
|
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
|
||||||
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
|
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
|
||||||
@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
if self.ftx_gamma > 1e-6:
|
if self.ftx_gamma > 1e-6:
|
||||||
losses += self.ftx_gamma * sft_loss
|
losses += self.ftx_gamma * sft_loss
|
||||||
|
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
|
||||||
|
|
||||||
prefix = "eval_" if train_eval == "eval" else ""
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
|
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
|
||||||
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
|
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
|
||||||
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
|
metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
|
||||||
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
|
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
|
||||||
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
|
metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item()
|
||||||
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
|
metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item()
|
||||||
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
|
metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item()
|
||||||
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
|
metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item()
|
||||||
if self.loss_type == "orpo":
|
if self.loss_type == "orpo":
|
||||||
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
|
metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
|
||||||
metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
|
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
|
||||||
|
|
||||||
return losses.mean(), metrics
|
return losses.mean(), metrics
|
||||||
|
|
||||||
|
@override
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||||
|
r"""
|
||||||
|
Fixes the loss value for transformers 4.46.0.
|
||||||
|
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
|
||||||
|
"""
|
||||||
|
loss = super().compute_loss(model, inputs, return_outputs)
|
||||||
|
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
|
||||||
|
if return_outputs:
|
||||||
|
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
|
||||||
|
else:
|
||||||
|
return loss / self.args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
@override
|
||||||
|
def log(self, logs: Dict[str, float]) -> None:
|
||||||
|
r"""
|
||||||
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
"""
|
||||||
|
# logs either has "loss" or "eval_loss"
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
key_list, metric_list = [], []
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
key_list.append(key)
|
||||||
|
metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())
|
||||||
|
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
if len(metric_list) < 10: # pad to for all reduce
|
||||||
|
for i in range(10 - len(metric_list)):
|
||||||
|
key_list.append(f"dummy_{i}")
|
||||||
|
metric_list.append(0.0)
|
||||||
|
|
||||||
|
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
|
||||||
|
metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
|
||||||
|
for key, metric in zip(key_list, metric_list): # add remaining items
|
||||||
|
if not key.startswith("dummy_"):
|
||||||
|
logs[key] = metric
|
||||||
|
|
||||||
|
return Trainer.log(self, logs)
|
||||||
|
@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
from ...extras.packages import is_transformers_version_equal_to_4_46
|
||||||
from ..callbacks import SaveProcessorCallback
|
from ..callbacks import SaveProcessorCallback
|
||||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
|
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
|
||||||
|
|
||||||
@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
self.add_callback(SaveProcessorCallback(processor))
|
self.add_callback(SaveProcessorCallback(processor))
|
||||||
|
|
||||||
if finetuning_args.use_badam:
|
if finetuning_args.use_badam:
|
||||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||||
|
|
||||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
"""
|
"""
|
||||||
return Trainer._get_train_sampler(self)
|
return Trainer._get_train_sampler(self)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||||
|
r"""
|
||||||
|
Replaces the method of KTO Trainer with the one of the standard Trainer.
|
||||||
|
"""
|
||||||
|
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def forward(
|
def forward(
|
||||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||||
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||||
r"""
|
r"""
|
||||||
Runs forward pass and computes the log probabilities.
|
Runs forward pass and computes the log probabilities.
|
||||||
"""
|
"""
|
||||||
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||||
model_inputs = {
|
model_inputs = {
|
||||||
"input_ids": batch["{}input_ids".format(prefix)],
|
"input_ids": batch[f"{prefix}input_ids"],
|
||||||
"attention_mask": batch["{}attention_mask".format(prefix)],
|
"attention_mask": batch[f"{prefix}attention_mask"],
|
||||||
}
|
}
|
||||||
if "{}token_type_ids".format(prefix) in batch:
|
if f"{prefix}token_type_ids" in batch:
|
||||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
|
||||||
|
|
||||||
if "pixel_values" in batch:
|
if "pixel_values" in batch:
|
||||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||||
@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||||
|
|
||||||
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
|
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
|
||||||
return logps, logps / valid_length
|
return logits, logps, logps / valid_length
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def concatenated_forward(
|
def concatenated_forward(
|
||||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||||
target_logps, target_logps_avg = self.forward(model, batch)
|
target_logits, target_logps, target_logps_avg = self.forward(model, batch)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
kl_logps, _ = self.forward(model, batch, prefix="kl_")
|
_, kl_logps, _ = self.forward(model, batch, prefix="kl_")
|
||||||
|
|
||||||
if len(target_logps) != len(batch["kto_tags"]):
|
if len(target_logps) != len(batch["kto_tags"]):
|
||||||
raise ValueError("Mismatched shape of inputs and labels.")
|
raise ValueError("Mismatched shape of inputs and labels.")
|
||||||
|
|
||||||
|
chosen_logits = target_logits[batch["kto_tags"]]
|
||||||
chosen_logps = target_logps[batch["kto_tags"]]
|
chosen_logps = target_logps[batch["kto_tags"]]
|
||||||
|
rejected_logits = target_logits[~batch["kto_tags"]]
|
||||||
rejected_logps = target_logps[~batch["kto_tags"]]
|
rejected_logps = target_logps[~batch["kto_tags"]]
|
||||||
chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
|
chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
|
||||||
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
|
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def compute_reference_log_probs(
|
def compute_reference_log_probs(
|
||||||
@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
ref_context = nullcontext()
|
ref_context = nullcontext()
|
||||||
|
|
||||||
with torch.no_grad(), ref_context:
|
with torch.no_grad(), ref_context:
|
||||||
reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
|
reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
|
||||||
ref_model, batch
|
ref_model, batch
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||||
"""
|
"""
|
||||||
metrics = {}
|
metrics = {}
|
||||||
policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
|
(
|
||||||
self.concatenated_forward(model, batch)
|
policy_chosen_logps,
|
||||||
)
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits,
|
||||||
|
policy_rejected_logits,
|
||||||
|
policy_kl_logps,
|
||||||
|
policy_chosen_logps_avg,
|
||||||
|
) = self.concatenated_forward(model, batch)
|
||||||
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
|
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
|
||||||
model, batch
|
model, batch
|
||||||
)
|
)
|
||||||
@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
sft_loss = -policy_chosen_logps_avg
|
sft_loss = -policy_chosen_logps_avg
|
||||||
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
|
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
|
||||||
|
|
||||||
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
num_chosen = len(chosen_rewards)
|
||||||
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
num_rejected = len(rejected_rewards)
|
||||||
|
if num_chosen > 0:
|
||||||
|
metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
|
||||||
|
metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
|
||||||
|
metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
|
||||||
|
metrics["count/chosen"] = float(num_chosen)
|
||||||
|
|
||||||
all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
|
if num_rejected > 0:
|
||||||
all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
|
metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
|
||||||
|
metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
|
||||||
if all_num_chosen > 0:
|
metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
|
||||||
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
|
metrics["count/rejected"] = float(num_rejected)
|
||||||
metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
|
|
||||||
metrics["count/chosen"] = all_num_chosen
|
|
||||||
|
|
||||||
if all_num_rejected > 0:
|
|
||||||
metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
|
|
||||||
metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
|
|
||||||
metrics["count/rejected"] = all_num_rejected
|
|
||||||
|
|
||||||
metrics["kl"] = kl.item()
|
metrics["kl"] = kl.item()
|
||||||
|
|
||||||
return losses, metrics
|
return losses, metrics
|
||||||
|
|
||||||
|
@override
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||||
|
r"""
|
||||||
|
Fixes the loss value for transformers 4.46.0.
|
||||||
|
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
|
||||||
|
"""
|
||||||
|
loss = super().compute_loss(model, inputs, return_outputs)
|
||||||
|
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
|
||||||
|
if return_outputs:
|
||||||
|
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
|
||||||
|
else:
|
||||||
|
return loss / self.args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
@override
|
||||||
|
def log(self, logs: Dict[str, float]) -> None:
|
||||||
|
r"""
|
||||||
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
"""
|
||||||
|
# logs either has "loss" or "eval_loss"
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
key_list, metric_list = [], []
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
key_list.append(key)
|
||||||
|
metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())
|
||||||
|
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
if len(metric_list) < 9: # pad to for all reduce
|
||||||
|
for i in range(9 - len(metric_list)):
|
||||||
|
key_list.append(f"dummy_{i}")
|
||||||
|
metric_list.append(0.0)
|
||||||
|
|
||||||
|
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
|
||||||
|
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
|
||||||
|
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
|
||||||
|
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
|
||||||
|
if f"count/{split}" in metric_dict:
|
||||||
|
for key in ("rewards", "logps", "logits"):
|
||||||
|
logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
|
||||||
|
del metric_dict[f"{key}/{split}_sum"]
|
||||||
|
del metric_dict[f"count/{split}"]
|
||||||
|
|
||||||
|
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: # calculate reward margin
|
||||||
|
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
||||||
|
|
||||||
|
for key, metric in metric_dict.items(): # add remaining items
|
||||||
|
if not key.startswith("dummy_"):
|
||||||
|
logs[key] = metric
|
||||||
|
|
||||||
|
return Trainer.log(self, logs)
|
||||||
|
@ -81,7 +81,7 @@ def run_kto(
|
|||||||
trainer.save_metrics("train", train_result.metrics)
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
trainer.save_state()
|
trainer.save_state()
|
||||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
|
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
|
@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
|
|||||||
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
|
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
|
||||||
|
|
||||||
device = v_head_layer.weight.device
|
device = v_head_layer.weight.device
|
||||||
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
|
||||||
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
|
||||||
|
|
||||||
|
|
||||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
|
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
|
||||||
|
@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
|
|||||||
from trl.models.utils import unwrap_model_for_generation
|
from trl.models.utils import unwrap_model_for_generation
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ...extras.logging import get_logger
|
from ...extras import logging
|
||||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
|
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
|
||||||
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
|
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
|
||||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||||
@ -58,7 +58,7 @@ if TYPE_CHECKING:
|
|||||||
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
|
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CustomPPOTrainer(PPOTrainer, Trainer):
|
class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||||
@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
]
|
]
|
||||||
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
|
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
|
||||||
if ppo_config.log_with is not None:
|
if ppo_config.log_with is not None:
|
||||||
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
|
logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
|
||||||
ppo_config.log_with = None
|
ppo_config.log_with = None
|
||||||
|
|
||||||
# Create optimizer and scheduler
|
# Create optimizer and scheduler
|
||||||
@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
|
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
|
||||||
)
|
)
|
||||||
if self.args.max_steps > 0:
|
if self.args.max_steps > 0:
|
||||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")
|
||||||
|
|
||||||
self.amp_context = torch.autocast(self.current_device.type)
|
self.amp_context = torch.autocast(self.current_device.type)
|
||||||
warnings.simplefilter("ignore") # remove gc warnings on ref model
|
warnings.simplefilter("ignore") # remove gc warnings on ref model
|
||||||
@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
self.add_callback(SaveProcessorCallback(processor))
|
self.add_callback(SaveProcessorCallback(processor))
|
||||||
|
|
||||||
if finetuning_args.use_badam:
|
if finetuning_args.use_badam:
|
||||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||||
|
|
||||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
self.state.is_local_process_zero = self.is_local_process_zero()
|
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||||
self.state.is_world_process_zero = self.is_world_process_zero()
|
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||||
|
|
||||||
if self.is_world_process_zero():
|
logger.info_rank0("***** Running training *****")
|
||||||
logger.info("***** Running training *****")
|
logger.info_rank0(f" Num examples = {num_examples:,}")
|
||||||
logger.info(" Num examples = {:,}".format(num_examples))
|
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
|
||||||
logger.info(" Num Epochs = {:,}".format(num_train_epochs))
|
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
|
||||||
logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
|
logger.info_rank0(
|
||||||
logger.info(
|
|
||||||
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
|
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
|
||||||
total_train_batch_size
|
total_train_batch_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
|
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
|
||||||
logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
|
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
|
||||||
logger.info(" Total training steps = {:,}".format(max_steps))
|
logger.info_rank0(f" Total training steps = {max_steps:,}")
|
||||||
logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
|
logger.info_rank0(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
|
||||||
|
|
||||||
dataiter = iter(self.dataloader)
|
dataiter = iter(self.dataloader)
|
||||||
loss_meter = AverageMeter()
|
loss_meter = AverageMeter()
|
||||||
@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||||
self.log_stats(stats, batch, rewards)
|
self.log_stats(stats, batch, rewards)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to save stats due to unknown errors.")
|
logger.warning_rank0("Failed to save stats due to unknown errors.")
|
||||||
|
|
||||||
self.state.global_step += 1
|
self.state.global_step += 1
|
||||||
self.callback_handler.on_step_end(self.args, self.state, self.control)
|
self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||||
@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
|
|
||||||
if (step + 1) % self.args.save_steps == 0: # save checkpoint
|
if (step + 1) % self.args.save_steps == 0: # save checkpoint
|
||||||
self.save_model(
|
self.save_model(
|
||||||
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
|
os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
|
||||||
)
|
)
|
||||||
self.callback_handler.on_save(self.args, self.state, self.control)
|
self.callback_handler.on_save(self.args, self.state, self.control)
|
||||||
|
|
||||||
@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self._save(output_dir, state_dict=state_dict)
|
self._save(output_dir, state_dict=state_dict)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(
|
logger.warning_rank0(
|
||||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
|
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
|
||||||
" use zero_to_fp32.py to recover weights"
|
" use zero_to_fp32.py to recover weights"
|
||||||
)
|
)
|
||||||
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from typing_extensions import override

- from ...extras.logging import get_logger
+ from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@ -30,9 +30,6 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments

- logger = get_logger(__name__)

class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.

@ -51,7 +48,7 @@ class CustomTrainer(Trainer):
self.add_callback(PissaConvertCallback)

if finetuning_args.use_badam:
- from badam import BAdamCallback, clip_grad_norm_old_version
+ from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

@ -68,3 +65,19 @@ class CustomTrainer(Trainer):
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)

+ @override
+ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+     r"""
+     Fixes the loss value for transformers 4.46.0.
+     https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+     """
+     loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+     if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+         # other model should not scale the loss
+         if return_outputs:
+             return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+         else:
+             return loss / self.args.gradient_accumulation_steps
+
+     return loss
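Editor's note on the `compute_loss` override added above (not part of the diff itself): transformers 4.46.0 changed how the Trainer handles loss scaling under gradient accumulation, and for models that do not accept the new loss kwargs the value returned to the training loop appears not to be divided by `gradient_accumulation_steps` the way earlier versions did. The override restores that division. A minimal sketch of the arithmetic, with made-up values:

```python
# Illustrative sketch only: why the override divides by gradient_accumulation_steps
# on transformers 4.46.0 when the model does not accept loss kwargs.
gradient_accumulation_steps = 4
loss_from_super = 2.0  # hypothetical value returned by super().compute_loss()

# Dividing restores the per-optimizer-step value that older versions reported.
fixed_loss = loss_from_super / gradient_accumulation_steps
assert fixed_loss == 0.5
```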
@ -24,7 +24,8 @@ import torch
from transformers import Trainer
from typing_extensions import override

- from ...extras.logging import get_logger
+ from ...extras import logging
+ from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@ -36,7 +37,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments

- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)

class PairwiseTrainer(Trainer):

@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer):
self.add_callback(PissaConvertCallback)

if finetuning_args.use_badam:
- from badam import BAdamCallback, clip_grad_norm_old_version
+ from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
@override
def compute_loss(
- self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
+ self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()

+ if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+     loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0

if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
else:

@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer):
return

output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
- logger.info(f"Saving prediction results to {output_prediction_file}")
+ logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions

with open(output_prediction_file, "w", encoding="utf-8") as writer:
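Editor's aside on the pairwise loss left unchanged in the hunk above: it is the usual Bradley-Terry-style reward-model objective, which pushes the chosen response's score above the rejected one's. A self-contained sketch with made-up scores:

```python
# Illustrative sketch only: the pairwise -log(sigmoid(chosen - rejected)) loss
# used by PairwiseTrainer.compute_loss above, on hypothetical scores.
import torch

chosen_scores = torch.tensor([2.0, 1.5])   # hypothetical reward-model outputs
rejected_scores = torch.tensor([0.5, 1.0])

# The loss is small when chosen outranks rejected by a wide margin.
loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
print(round(loss.item(), 2))  # ~0.34 for the values above
```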
@ -25,8 +25,9 @@ import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

+ from ...extras import logging
from ...extras.constants import IGNORE_INDEX
- from ...extras.logging import get_logger
+ from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@ -39,7 +40,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments

- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)

class CustomSeq2SeqTrainer(Seq2SeqTrainer):

@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.add_callback(PissaConvertCallback)

if finetuning_args.use_badam:
- from badam import BAdamCallback, clip_grad_norm_old_version
+ from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)

@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)

+ @override
+ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+     r"""
+     Fixes the loss value for transformers 4.46.0.
+     https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+     """
+     loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+     if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+         # other model should not scale the loss
+         if return_outputs:
+             return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+         else:
+             return loss / self.args.gradient_accumulation_steps
+
+     return loss

@override
def prediction_step(
self,

@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return

output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
- logger.info(f"Saving prediction results to {output_prediction_file}")
+ logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
if any(key in name for key in diff_keys):
- assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
else:
- assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True
+ assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True


def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:

@ -80,18 +80,17 @@ def load_reference_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
+ current_device = get_current_device()
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
- model_path, torch_dtype=torch.float16, device_map=get_current_device()
+ model_path, torch_dtype=torch.float16, device_map=current_device
)
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)

return model

- model = AutoModelForCausalLM.from_pretrained(
-     model_path, torch_dtype=torch.float16, device_map=get_current_device()
- )
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
if use_lora or use_pissa:
model = PeftModel.from_pretrained(
model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable

@ -110,7 +109,7 @@ def load_train_dataset(**kwargs) -> "Dataset":
return dataset_module["train_dataset"]


- def patch_valuehead_model():
+ def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False)
@ -28,8 +28,8 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override

+ from ..extras import logging
from ..extras.constants import IGNORE_INDEX
- from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params

@ -46,7 +46,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments

- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)

class DummyOptimizer(torch.optim.Optimizer):

@ -116,7 +116,7 @@ def create_ref_model(
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
- logger.info("Created reference model from {}".format(finetuning_args.ref_model))
+ logger.info_rank0(f"Created reference model from {finetuning_args.ref_model}")
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None

@ -127,7 +127,7 @@ def create_ref_model(
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
- logger.info("Created reference model from the model itself.")
+ logger.info_rank0("Created reference model from the model itself.")

return ref_model

@ -140,7 +140,7 @@ def create_reward_model(
"""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
- logger.info("Use reward server {}".format(finetuning_args.reward_model))
+ logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")

@ -157,7 +157,7 @@ def create_reward_model(
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
- logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
+ logger.info_rank0(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
return None
else:
reward_model_args = ModelArguments.copyfrom(

@ -171,8 +171,8 @@ def create_reward_model(
reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
- logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
- logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
+ logger.info_rank0(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
+ logger.warning_rank0("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model

@ -231,7 +231,7 @@ def _create_galore_optimizer(
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
else:
- raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
+ raise NotImplementedError(f"Unknow optim: {training_args.optim}")

if finetuning_args.galore_layerwise:
if training_args.gradient_accumulation_steps != 1:

@ -265,7 +265,7 @@ def _create_galore_optimizer(
]
optimizer = optim_class(param_groups, **optim_kwargs)

- logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+ logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
return optimizer

@ -305,7 +305,7 @@ def _create_loraplus_optimizer(
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
]
optimizer = optim_class(param_groups, **optim_kwargs)
- logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
+ logger.info_rank0(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
return optimizer

@ -343,7 +343,7 @@ def _create_badam_optimizer(
verbose=finetuning_args.badam_verbose,
ds_zero3_enabled=is_deepspeed_zero3_enabled(),
)
- logger.info(
+ logger.info_rank0(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_interval} steps, "
f"default start block is {finetuning_args.badam_start_block}"

@ -362,7 +362,7 @@ def _create_badam_optimizer(
include_embedding=False,
**optim_kwargs,
)
- logger.info(
+ logger.info_rank0(
f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)

@ -391,7 +391,7 @@ def _create_adam_mini_optimizer(
n_heads=num_q_head,
n_kv_heads=num_kv_head,
)
- logger.info("Using Adam-mini optimizer.")
+ logger.info_rank0("Using Adam-mini optimizer.")
return optimizer
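Editor's note on the widespread `logger.info(...)` → `logger.info_rank0(...)` change in this file and the trainers above: the intent is to emit each message once from the main process instead of once per GPU during distributed training. The actual helper lives in `..extras.logging` and is not shown in this diff; a hedged sketch of how such a rank-0 guard is commonly written:

```python
# Hedged sketch (assumption, not the project's actual implementation): a logging
# helper that only emits on the distributed main process (rank 0).
import logging
import os


def info_rank0(logger: logging.Logger, msg: str) -> None:
    # RANK is set by torchrun; an unset value is treated as rank 0 (single process).
    if int(os.environ.get("RANK", "0")) == 0:
        logger.info(msg)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    info_rank0(logging.getLogger(__name__), "printed once, not once per process")
```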
@ -20,8 +20,8 @@ import torch
from transformers import PreTrainedModel

from ..data import get_template_and_fix_tokenizer
+ from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
- from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback

@ -37,7 +37,7 @@ if TYPE_CHECKING:
from transformers import TrainerCallback

- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)

def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:

@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
else:
- raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
+ raise ValueError(f"Unknown task: {finetuning_args.stage}.")


def export_model(args: Optional[Dict[str, Any]] = None) -> None:

@ -91,18 +91,18 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
setattr(model.config, "torch_dtype", output_dtype)
model = model.to(output_dtype)
- logger.info("Convert model dtype to: {}.".format(output_dtype))
+ logger.info_rank0(f"Convert model dtype to: {output_dtype}.")

model.save_pretrained(
save_directory=model_args.export_dir,
- max_shard_size="{}GB".format(model_args.export_size),
+ max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)
if model_args.export_hub_model_id is not None:
model.push_to_hub(
model_args.export_hub_model_id,
token=model_args.hf_hub_token,
- max_shard_size="{}GB".format(model_args.export_size),
+ max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)

@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
)
- logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+ logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
shutil.copy(
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
)
- logger.info("Copied valuehead to {}.".format(model_args.export_dir))
+ logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")

try:
tokenizer.padding_side = "left"  # restore padding side

@ -133,11 +133,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

if processor is not None:
- getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
+ processor.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
- getattr(processor, "image_processor").push_to_hub(
-     model_args.export_hub_model_id, token=model_args.hf_hub_token
- )
+ processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

- except Exception:
-     logger.warning("Cannot save tokenizer, please copy the files manually.")
+ except Exception as e:
+     logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
@ -141,7 +141,14 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = ""
response = ""
for new_text in self.stream_chat(
- messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+ messages,
+ system,
+ tools,
+ images=[image] if image else None,
+ videos=[video] if video else None,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ temperature=temperature,
):
response += new_text
if tools:
@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, Tuple
from yaml import safe_dump, safe_load

+ from ..extras import logging
from ..extras.constants import (
CHECKPOINT_NAMES,
DATA_CONFIG,

@ -30,8 +31,7 @@ from ..extras.constants import (
VISION_MODELS,
DownloadSource,
)
- from ..extras.logging import get_logger
- from ..extras.misc import use_modelscope
+ from ..extras.misc import use_modelscope, use_openmind
from ..extras.packages import is_gradio_available

@ -39,7 +39,7 @@ if is_gradio_available():
import gradio as gr

- logger = get_logger(__name__)
+ logger = logging.get_logger(__name__)

DEFAULT_CACHE_DIR = "cache"

@ -56,7 +56,7 @@ def get_save_dir(*paths: str) -> os.PathLike:
Gets the path to saved model checkpoints.
"""
if os.path.sep in paths[-1]:
- logger.warning("Found complex path, some features may be not available.")
+ logger.warning_rank0("Found complex path, some features may be not available.")
return paths[-1]

paths = (path.replace(" ", "").strip() for path in paths)

@ -75,7 +75,7 @@ def load_config() -> Dict[str, Any]:
Loads user config if exists.
"""
try:
- with open(get_config_path(), "r", encoding="utf-8") as f:
+ with open(get_config_path(), encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}

@ -109,19 +109,19 @@ def get_model_path(model_name: str) -> str:
use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT)
- ):  # replace path
+ ):  # replace hf path with ms path
model_path = path_dict.get(DownloadSource.MODELSCOPE)

+ if (
+     use_openmind()
+     and path_dict.get(DownloadSource.OPENMIND)
+     and model_path == path_dict.get(DownloadSource.DEFAULT)
+ ):  # replace hf path with om path
+     model_path = path_dict.get(DownloadSource.OPENMIND)

return model_path

- def get_prefix(model_name: str) -> str:
-     r"""
-     Gets the prefix of the model name to obtain the model family.
-     """
-     return model_name.split("-")[0]

def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
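Editor's note on the `get_model_path` change above: when the ModelScope or OpenMind hub is enabled and the configured path is still the Hugging Face default, the path is swapped for the corresponding hub identifier. The `use_modelscope()` / `use_openmind()` helpers imported from `..extras.misc` are not shown in this diff; a hedged sketch of the likely environment-variable check, based on the `USE_MODELSCOPE_HUB` / `USE_OPENMIND_HUB` entries documented in `.env.local`:

```python
# Hedged sketch (assumption, not the project's actual code): hub selection driven
# by the environment variables listed in .env.local.
import os


def _env_flag(name: str) -> bool:
    return os.environ.get(name, "0").lower() in ("1", "true", "y", "yes")


def use_modelscope() -> bool:
    return _env_flag("USE_MODELSCOPE_HUB")


def use_openmind() -> bool:
    return _env_flag("USE_OPENMIND_HUB")
```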
@ -137,21 +137,14 @@ def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat model.
"""
- if (
-     model_name
-     and any(suffix in model_name for suffix in ("-Chat", "-Instruct"))
-     and get_prefix(model_name) in DEFAULT_TEMPLATE
- ):
-     return DEFAULT_TEMPLATE[get_prefix(model_name)]
-
- return "default"
+ return DEFAULT_TEMPLATE.get(model_name, "default")


def get_visual(model_name: str) -> bool:
r"""
Judges if the model is a vision language model.
"""
- return get_prefix(model_name) in VISION_MODELS
+ return model_name in VISION_MODELS


def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":

@ -179,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
Loads dataset_info.json.
"""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
- logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir))
+ logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}

try:
- with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
return json.load(f)
except Exception as err:
- logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
+ logger.warning_rank0(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
return {}
@ -41,7 +41,7 @@ def next_page(page_index: int, total_num: int) -> int:

def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
try:
- with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
except Exception:
return gr.Button(interactive=False)

@ -57,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":

def _load_data_file(file_path: str) -> List[Any]:
- with open(file_path, "r", encoding="utf-8") as f:
+ with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
elif file_path.endswith(".jsonl"):

@ -67,7 +67,7 @@ def _load_data_file(file_path: str) -> List[Any]:

def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
- with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
+ with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)

data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
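Editor's note on the repeated `open(..., "r", encoding="utf-8")` → `open(..., encoding="utf-8")` change in this file and in webui/common.py: `"r"` is already Python's default mode, so dropping it is purely cosmetic. A small self-contained check:

```python
# "r" is the default open() mode, so the two forms below read the same content.
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    tmp.write("{}")

with open(tmp.name, encoding="utf-8") as f:       # new style in the diff
    assert f.read() == "{}"
with open(tmp.name, "r", encoding="utf-8") as f:  # old style
    assert f.read() == "{}"
```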
Some files were not shown because too many files have changed in this diff.