mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 08:02:51 +08:00
Merge branch 'main' into main
Former-commit-id: 5f14910910154ba569435e7e68acbd6c30f79e80
This commit is contained in:
commit
d99e164cad
@ -7,6 +7,8 @@ data
|
||||
docker
|
||||
saves
|
||||
hf_cache
|
||||
ms_cache
|
||||
om_cache
|
||||
output
|
||||
.dockerignore
|
||||
.gitattributes
|
||||
|
17
.env.local
17
.env.local
@ -1,32 +1,35 @@
|
||||
# Note: actually we do not support .env, just for reference
|
||||
# api
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
API_HOST=
|
||||
API_PORT=
|
||||
API_KEY=
|
||||
API_MODEL_NAME=gpt-3.5-turbo
|
||||
API_MODEL_NAME=
|
||||
FASTAPI_ROOT_PATH=
|
||||
MAX_CONCURRENT=
|
||||
# general
|
||||
DISABLE_VERSION_CHECK=
|
||||
FORCE_CHECK_IMPORTS=
|
||||
LLAMAFACTORY_VERBOSITY=
|
||||
USE_MODELSCOPE_HUB=
|
||||
USE_OPENMIND_HUB=
|
||||
RECORD_VRAM=
|
||||
# torchrun
|
||||
FORCE_TORCHRUN=
|
||||
MASTER_ADDR=
|
||||
MASTER_PORT=
|
||||
NNODES=
|
||||
RANK=
|
||||
NODE_RANK=
|
||||
NPROC_PER_NODE=
|
||||
# wandb
|
||||
WANDB_DISABLED=
|
||||
WANDB_PROJECT=huggingface
|
||||
WANDB_PROJECT=
|
||||
WANDB_API_KEY=
|
||||
# gradio ui
|
||||
GRADIO_SHARE=False
|
||||
GRADIO_SERVER_NAME=0.0.0.0
|
||||
GRADIO_SHARE=
|
||||
GRADIO_SERVER_NAME=
|
||||
GRADIO_SERVER_PORT=
|
||||
GRADIO_ROOT_PATH=
|
||||
GRADIO_IPV6=
|
||||
# setup
|
||||
ENABLE_SHORT_CONSOLE=1
|
||||
# reserved (do not use)
|
||||
|
46
.github/CONTRIBUTING.md
vendored
46
.github/CONTRIBUTING.md
vendored
@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:
|
||||
### Style guide
|
||||
|
||||
LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
|
||||
|
||||
### Create a Pull Request
|
||||
|
||||
1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:[username]/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes:
|
||||
|
||||
```bash
|
||||
git checkout -b dev_your_branch
|
||||
```
|
||||
|
||||
4. Set up a development environment by running the following command in a virtual environment:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the -e flag.
|
||||
|
||||
5. Check code before commit:
|
||||
|
||||
```bash
|
||||
make commit
|
||||
make style && make quality
|
||||
make test
|
||||
```
|
||||
|
||||
6. Submit changes:
|
||||
|
||||
```bash
|
||||
git add .
|
||||
git commit -m "commit message"
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
git push -u origin dev_your_branch
|
||||
```
|
||||
|
||||
7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).
|
||||
|
6
.github/workflows/publish.yml
vendored
6
.github/workflows/publish.yml
vendored
@ -26,15 +26,15 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.8"
|
||||
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install build
|
||||
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
python -m build
|
||||
|
||||
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@ -22,7 +22,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.8"
|
||||
- "3.8" # TODO: remove py38 in next transformers release
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
@ -54,7 +54,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install git+https://github.com/huggingface/transformers.git
|
||||
python -m pip install ".[torch,dev]"
|
||||
|
||||
- name: Check quality
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -162,6 +162,7 @@ cython_debug/
|
||||
# custom .gitignore
|
||||
ms_cache/
|
||||
hf_cache/
|
||||
om_cache/
|
||||
cache/
|
||||
config/
|
||||
saves/
|
||||
|
28
.pre-commit-config.yaml
Normal file
28
.pre-commit-config.yaml
Normal file
@ -0,0 +1,28 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-ast
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=25000']
|
||||
- id: check-merge-conflict
|
||||
- id: check-yaml
|
||||
- id: debug-statements
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
args: [--markdown-linebreak-ext=md]
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.17.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py38-plus]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.9
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
11
Makefile
11
Makefile
@ -1,7 +1,14 @@
|
||||
.PHONY: quality style test
|
||||
.PHONY: build commit quality style test
|
||||
|
||||
check_dirs := scripts src tests setup.py
|
||||
|
||||
build:
|
||||
pip install build && python -m build
|
||||
|
||||
commit:
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
|
||||
quality:
|
||||
ruff check $(check_dirs)
|
||||
ruff format --check $(check_dirs)
|
||||
@ -11,4 +18,4 @@ style:
|
||||
ruff format $(check_dirs)
|
||||
|
||||
test:
|
||||
CUDA_VISIBLE_DEVICES= pytest tests/
|
||||
CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
|
||||
|
104
README.md
104
README.md
@ -4,7 +4,7 @@
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llamafactory/)
|
||||
[](#projects-using-llama-factory)
|
||||
[](#projects-using-llama-factory)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
@ -26,10 +26,17 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3
|
||||
Choose your path:
|
||||
|
||||
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
|
||||
- **Local machine**: Please refer to [usage](#getting-started)
|
||||
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||
|
||||
Recent activities:
|
||||
|
||||
- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)
|
||||
|
||||
> [!NOTE]
|
||||
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Features](#features)
|
||||
@ -72,6 +79,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Changelog
|
||||
|
||||
[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.
|
||||
|
||||
[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.
|
||||
|
||||
[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.
|
||||
@ -130,7 +139,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
|
||||
|
||||
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
|
||||
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.
|
||||
|
||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
|
||||
|
||||
@ -162,36 +171,39 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Model size | Template |
|
||||
| ----------------------------------------------------------------- | -------------------------------- | --------- |
|
||||
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi-small |
|
||||
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
|
||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
| Model | Model size | Template |
|
||||
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
|
||||
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
|
||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
|
||||
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
|
||||
@ -360,7 +372,7 @@ cd LLaMA-Factory
|
||||
pip install -e ".[torch,metrics]"
|
||||
```
|
||||
|
||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
|
||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality
|
||||
|
||||
> [!TIP]
|
||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
|
||||
@ -412,7 +424,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij
|
||||
|
||||
### Data Preparation
|
||||
|
||||
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
|
||||
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope / Modelers hub or load the dataset in local disk.
|
||||
|
||||
> [!NOTE]
|
||||
> Please update `data/dataset_info.json` to use your custom dataset.
|
||||
@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
|
||||
docker run -dit --gpus=all \
|
||||
-v ./hf_cache:/root/.cache/huggingface \
|
||||
-v ./ms_cache:/root/.cache/modelscope \
|
||||
-v ./om_cache:/root/.cache/openmind \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-p 7860:7860 \
|
||||
@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
|
||||
docker run -dit \
|
||||
-v ./hf_cache:/root/.cache/huggingface \
|
||||
-v ./ms_cache:/root/.cache/modelscope \
|
||||
-v ./om_cache:/root/.cache/openmind \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||
@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
|
||||
docker run -dit \
|
||||
-v ./hf_cache:/root/.cache/huggingface \
|
||||
-v ./ms_cache:/root/.cache/modelscope \
|
||||
-v ./om_cache:/root/.cache/openmind \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-v ./saves:/app/saves \
|
||||
@ -557,6 +572,7 @@ docker exec -it llamafactory bash
|
||||
|
||||
- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
|
||||
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
|
||||
- `om_cache`: Similar to Hugging Face cache but for Modelers users.
|
||||
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
|
||||
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.
|
||||
|
||||
@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
|
||||
> [!TIP]
|
||||
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
|
||||
>
|
||||
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
|
||||
|
||||
### Download from ModelScope Hub
|
||||
|
||||
@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
|
||||
|
||||
Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
|
||||
|
||||
### Download from Modelers Hub
|
||||
|
||||
You can also use Modelers Hub to download models and datasets.
|
||||
|
||||
```bash
|
||||
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
|
||||
```
|
||||
|
||||
Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
|
||||
|
||||
### Use W&B Logger
|
||||
|
||||
To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
|
||||
@ -684,11 +712,13 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
|
||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
|
||||
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
||||
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
|
||||
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
|
||||
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
@ -696,7 +726,7 @@ If you have a project that should be incorporated, please contact via email or c
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## Citation
|
||||
|
||||
|
99
README_zh.md
99
README_zh.md
@ -4,7 +4,7 @@
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llamafactory/)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](#使用了-llama-factory-的项目)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/rKfvV9r9FK)
|
||||
[](https://twitter.com/llamafactory_ai)
|
||||
@ -26,11 +26,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
选择你的打开方式:
|
||||
|
||||
- **Colab**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||
- **PAI-DSW**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **PAI-DSW**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
|
||||
- **本地机器**:请见[如何使用](#如何使用)
|
||||
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
||||
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||
|
||||
近期活动:
|
||||
|
||||
- **2024/10/18-2024/11/30**:使用 PAI+LLaMA Factory 构建个性化导游机器人。[[活动页面]](https://developer.aliyun.com/topic/llamafactory2)
|
||||
|
||||
> [!NOTE]
|
||||
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
||||
|
||||
## 目录
|
||||
|
||||
- [项目特色](#项目特色)
|
||||
@ -73,6 +80,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
|
||||
## 更新日志
|
||||
|
||||
[24/10/09] 我们支持了从 **[魔乐社区](https://modelers.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#从魔乐社区下载)。
|
||||
|
||||
[24/09/19] 我们支持了 **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** 模型的微调。
|
||||
|
||||
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。感谢 [@simonJJJ](https://github.com/simonJJJ) 的 PR。
|
||||
@ -163,35 +172,38 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | Template |
|
||||
| ----------------------------------------------------------------- | -------------------------------- | --------- |
|
||||
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2.5 (Code/Math)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B | qwen |
|
||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
| 模型名 | 模型大小 | Template |
|
||||
| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
|
||||
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
|
||||
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
|
||||
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
|
||||
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
|
||||
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
|
||||
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
|
||||
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
|
||||
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
|
||||
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
|
||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
|
||||
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
|
||||
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
|
||||
|
||||
> [!NOTE]
|
||||
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
|
||||
@ -360,7 +372,7 @@ cd LLaMA-Factory
|
||||
pip install -e ".[torch,metrics]"
|
||||
```
|
||||
|
||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
|
||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、openmind、quality
|
||||
|
||||
> [!TIP]
|
||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
|
||||
@ -412,7 +424,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
|
||||
### 数据准备
|
||||
|
||||
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
|
||||
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope / Modelers 上的数据集或加载本地数据集。
|
||||
|
||||
> [!NOTE]
|
||||
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。
|
||||
@ -480,6 +492,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
|
||||
docker run -dit --gpus=all \
|
||||
-v ./hf_cache:/root/.cache/huggingface \
|
||||
-v ./ms_cache:/root/.cache/modelscope \
|
||||
-v ./om_cache:/root/.cache/openmind \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-p 7860:7860 \
|
||||
@ -504,6 +517,7 @@ docker build -f ./docker/docker-npu/Dockerfile \
|
||||
docker run -dit \
|
||||
-v ./hf_cache:/root/.cache/huggingface \
|
||||
-v ./ms_cache:/root/.cache/modelscope \
|
||||
-v ./om_cache:/root/.cache/openmind \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||
@ -537,6 +551,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \
|
||||
docker run -dit \
|
||||
-v ./hf_cache:/root/.cache/huggingface \
|
||||
-v ./ms_cache:/root/.cache/modelscope \
|
||||
-v ./om_cache:/root/.cache/openmind \
|
||||
-v ./data:/app/data \
|
||||
-v ./output:/app/output \
|
||||
-v ./saves:/app/saves \
|
||||
@ -557,6 +572,7 @@ docker exec -it llamafactory bash
|
||||
|
||||
- `hf_cache`:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
|
||||
- `ms_cache`:类似 Hugging Face 缓存文件夹,为 ModelScope 用户提供。
|
||||
- `om_cache`:类似 Hugging Face 缓存文件夹,为 Modelers 用户提供。
|
||||
- `data`:宿主机中存放数据集的文件夹路径。
|
||||
- `output`:将导出目录设置为该路径后,即可在宿主机中访问导出后的模型。
|
||||
|
||||
@ -570,6 +586,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
|
||||
|
||||
> [!TIP]
|
||||
> API 文档请查阅[这里](https://platform.openai.com/docs/api-reference/chat/create)。
|
||||
>
|
||||
> 示例:[图像理解](scripts/test_image.py) | [工具调用](scripts/test_toolcall.py)
|
||||
|
||||
### 从魔搭社区下载
|
||||
|
||||
@ -581,6 +599,16 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
|
||||
|
||||
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔搭社区](https://modelscope.cn/models)查看所有可用的模型,例如 `LLM-Research/Meta-Llama-3-8B-Instruct`。
|
||||
|
||||
### 从魔乐社区下载
|
||||
|
||||
您也可以通过下述方法,使用魔乐社区下载数据集和模型。
|
||||
|
||||
```bash
|
||||
export USE_OPENMIND_HUB=1 # Windows 使用 `set USE_OPENMIND_HUB=1`
|
||||
```
|
||||
|
||||
将 `model_name_or_path` 设置为模型 ID 来加载对应的模型。在[魔乐社区](https://modelers.cn/models)查看所有可用的模型,例如 `TeleAI/TeleChat-7B-pt`。
|
||||
|
||||
### 使用 W&B 面板
|
||||
|
||||
若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据,请在 yaml 文件中添加下面的参数。
|
||||
@ -684,11 +712,12 @@ run_name: test_run # 可选
|
||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
|
||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**:一个用于生成 Stable Diffusion 提示词的大型语言模型。[[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**:中文多模态医学大模型,基于 LLaVA-1.5-7B 在中文多模态医疗数据上微调而得。
|
||||
1. **[AutoRE](https://github.com/THUDM/AutoRE)**:基于大语言模型的文档级关系抽取系统。
|
||||
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**:在 Windows 主机上利用英伟达 RTX 设备进行大型语言模型微调的开发包。
|
||||
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**:一个低代码构建多 Agent 大模型应用的开发工具,支持基于 LLaMA Factory 的模型微调.
|
||||
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
|
||||
|
||||
</details>
|
||||
|
||||
@ -696,7 +725,7 @@ run_name: test_run # 可选
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## 引用
|
||||
|
||||
|
1630
assets/benchmark.svg
1630
assets/benchmark.svg
File diff suppressed because it is too large
Load Diff
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 28 KiB |
Binary file not shown.
Before Width: | Height: | Size: 145 KiB After Width: | Height: | Size: 132 KiB |
Binary file not shown.
Before Width: | Height: | Size: 149 KiB After Width: | Height: | Size: 132 KiB |
@ -4999,4 +4999,4 @@
|
||||
"input": "Time waits for no one.",
|
||||
"output": "No one can stop time from moving forward."
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -4999,4 +4999,4 @@
|
||||
"input": "",
|
||||
"output": "安第斯山脉位于南美洲,横跨七个国家,包括委内瑞拉,哥伦比亚,厄瓜多尔,秘鲁,玻利维亚,智利和阿根廷。安第斯山脉是世界上最长的山脉之一,全长约7,000千米(4,350英里),其山脉沿着南美洲西海岸蜿蜒延伸,平均海拔约为4,000米(13,000英尺)。在其南部,安第斯山脉宽度达到700千米(430英里),在其北部宽度约为500千米(310英里)。"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -17,9 +17,9 @@ _CITATION = """\
|
||||
}
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
|
||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
|
||||
_LICENSE = "gpl-3.0"
|
||||
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
|
||||
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
|
||||
|
||||
|
||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||
|
||||
def _generate_examples(self, filepath: str):
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
for key, row in enumerate(f):
|
||||
data = json.loads(row)
|
||||
conversations = []
|
||||
|
File diff suppressed because one or more lines are too long
@ -54,7 +54,8 @@
|
||||
},
|
||||
"alpaca_en": {
|
||||
"hf_hub_url": "llamafactory/alpaca_en",
|
||||
"ms_hub_url": "llamafactory/alpaca_en"
|
||||
"ms_hub_url": "llamafactory/alpaca_en",
|
||||
"om_hub_url": "HaM/alpaca_en"
|
||||
},
|
||||
"alpaca_zh": {
|
||||
"hf_hub_url": "llamafactory/alpaca_zh",
|
||||
@ -66,7 +67,8 @@
|
||||
},
|
||||
"alpaca_gpt4_zh": {
|
||||
"hf_hub_url": "llamafactory/alpaca_gpt4_zh",
|
||||
"ms_hub_url": "llamafactory/alpaca_gpt4_zh"
|
||||
"ms_hub_url": "llamafactory/alpaca_gpt4_zh",
|
||||
"om_hub_url": "State_Cloud/alpaca-gpt4-data-zh"
|
||||
},
|
||||
"glaive_toolcall_en": {
|
||||
"hf_hub_url": "llamafactory/glaive_toolcall_en",
|
||||
@ -623,4 +625,4 @@
|
||||
},
|
||||
"folder": "python"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7223,4 +7223,4 @@
|
||||
"value": "Abstraction is a principle in object-oriented programming that refers to the process of focusing on the essential features of an object or concept without emphasizing the details that are not important for its functionality. It enables programmers to create classes that represent the objects they want to work with, and specifies only the behavior of those objects, leaving out unnecessary details. Abstraction helps to make the code more maintainable, modular, and scalable. It also improves the productivity of developers by reducing the amount of code they need to write."
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -5055,4 +5055,4 @@
|
||||
"value": "C. 参与讨论"
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -9155,4 +9155,4 @@
|
||||
],
|
||||
"tools": "[]"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -9019,4 +9019,4 @@
|
||||
],
|
||||
"tools": "[]"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -8,9 +8,9 @@ import datasets
|
||||
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
|
||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
|
||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
|
||||
_LICENSE = "mit"
|
||||
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
|
||||
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
|
||||
_URLS = {
|
||||
"train": [
|
||||
_URL + "harmless-base/train.jsonl.gz",
|
||||
@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
key = 0
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
for row in f:
|
||||
data = json.loads(row)
|
||||
chosen = data["chosen"]
|
||||
|
@ -454,4 +454,4 @@
|
||||
"input": "",
|
||||
"output": "抱歉,我不是 OpenAI 开发的 ChatGPT,我是 {{author}} 开发的 {{name}},旨在为用户提供智能化的回答和帮助。"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -5395,4 +5395,4 @@
|
||||
],
|
||||
"label": false
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -137,4 +137,4 @@
|
||||
"mllm_demo_data/3.jpg"
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -44,4 +44,4 @@
|
||||
"mllm_demo_data/3.mp4"
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -20,9 +20,9 @@ _CITATION = """\
|
||||
}
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
|
||||
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
|
||||
_LICENSE = "cc-by-nc-4.0"
|
||||
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
|
||||
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"
|
||||
|
||||
|
||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
for row in f:
|
||||
try:
|
||||
data = json.loads(row)
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,6 +1,7 @@
|
||||
# Use the NVIDIA official image with PyTorch 2.3.0
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
|
||||
FROM nvcr.io/nvidia/pytorch:24.02-py3
|
||||
# Default use the NVIDIA official image with PyTorch 2.3.0
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
|
||||
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
# Define environments
|
||||
ENV MAX_JOBS=4
|
||||
@ -12,6 +13,9 @@ ARG INSTALL_BNB=false
|
||||
ARG INSTALL_VLLM=false
|
||||
ARG INSTALL_DEEPSPEED=false
|
||||
ARG INSTALL_FLASHATTN=false
|
||||
ARG INSTALL_LIGER_KERNEL=false
|
||||
ARG INSTALL_HQQ=false
|
||||
ARG INSTALL_EETQ=false
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
|
||||
# Set the working directory
|
||||
@ -38,6 +42,15 @@ RUN EXTRA_PACKAGES="metrics"; \
|
||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_HQQ" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_EETQ" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
|
||||
fi; \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"
|
||||
|
||||
# Rebuild flash attention
|
||||
|
@ -8,11 +8,15 @@ services:
|
||||
INSTALL_VLLM: false
|
||||
INSTALL_DEEPSPEED: false
|
||||
INSTALL_FLASHATTN: false
|
||||
INSTALL_LIGER_KERNEL: false
|
||||
INSTALL_HQQ: false
|
||||
INSTALL_EETQ: false
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory
|
||||
volumes:
|
||||
- ../../hf_cache:/root/.cache/huggingface
|
||||
- ../../ms_cache:/root/.cache/modelscope
|
||||
- ../../om_cache:/root/.cache/openmind
|
||||
- ../../data:/app/data
|
||||
- ../../output:/app/output
|
||||
ports:
|
||||
|
@ -10,6 +10,7 @@ services:
|
||||
volumes:
|
||||
- ../../hf_cache:/root/.cache/huggingface
|
||||
- ../../ms_cache:/root/.cache/modelscope
|
||||
- ../../om_cache:/root/.cache/openmind
|
||||
- ../../data:/app/data
|
||||
- ../../output:/app/output
|
||||
- /usr/local/dcmi:/usr/local/dcmi
|
||||
|
@ -10,6 +10,8 @@ ARG INSTALL_BNB=false
|
||||
ARG INSTALL_VLLM=false
|
||||
ARG INSTALL_DEEPSPEED=false
|
||||
ARG INSTALL_FLASHATTN=false
|
||||
ARG INSTALL_LIGER_KERNEL=false
|
||||
ARG INSTALL_HQQ=false
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
|
||||
# Set the working directory
|
||||
@ -36,6 +38,12 @@ RUN EXTRA_PACKAGES="metrics"; \
|
||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_HQQ" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
|
||||
fi; \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"
|
||||
|
||||
# Rebuild flash attention
|
||||
|
@ -8,11 +8,14 @@ services:
|
||||
INSTALL_VLLM: false
|
||||
INSTALL_DEEPSPEED: false
|
||||
INSTALL_FLASHATTN: false
|
||||
INSTALL_LIGER_KERNEL: false
|
||||
INSTALL_HQQ: false
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory
|
||||
volumes:
|
||||
- ../../hf_cache:/root/.cache/huggingface
|
||||
- ../../ms_cache:/root/.cache/modelscope
|
||||
- ../../om_cache:/root/.cache/openmind
|
||||
- ../../data:/app/data
|
||||
- ../../output:/app/output
|
||||
- ../../saves:/app/saves
|
||||
|
@ -207,4 +207,4 @@
|
||||
"name": "兽医学",
|
||||
"category": "STEM"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -267,4 +267,4 @@
|
||||
"name": "世界宗教",
|
||||
"category": "Humanities"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -227,4 +227,4 @@
|
||||
"name": "world religions",
|
||||
"category": "Humanities"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
|
||||
df = pd.read_csv(filepath, header=None)
|
||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
yield i, instance
|
||||
yield from enumerate(df.to_dict(orient="records"))
|
||||
|
@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||
#### Supervised Fine-Tuning on Multiple Nodes
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||
|
@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
|
||||
#### 多机指令监督微调
|
||||
|
||||
```bash
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 使用 DeepSpeed ZeRO-3 平均分配显存
|
||||
|
@ -25,4 +25,4 @@
|
||||
"contiguous_gradients": true,
|
||||
"round_robin_gradients": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -25,4 +25,4 @@
|
||||
"contiguous_gradients": true,
|
||||
"round_robin_gradients": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,4 +29,4 @@
|
||||
"contiguous_gradients": true,
|
||||
"round_robin_gradients": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -27,4 +27,4 @@
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -35,4 +35,4 @@
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
transformers>=4.41.2,<=4.45.0
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
accelerate>=0.30.1,<=0.34.2
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
datasets>=2.16.0,<=3.0.2
|
||||
accelerate>=0.34.0,<=1.0.1
|
||||
peft>=0.11.1,<=0.12.0
|
||||
trl>=0.8.6,<=0.9.6
|
||||
gradio>=4.0.0
|
||||
gradio>=4.0.0,<5.0.0
|
||||
pandas>=2.0.0
|
||||
scipy
|
||||
einops
|
||||
@ -19,3 +19,4 @@ fire
|
||||
packaging
|
||||
pyyaml
|
||||
numpy<2.0.0
|
||||
av
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Microsoft's DeepSpeed library.
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 imoneoi and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the imoneoi's OpenChat library.
|
||||
@ -74,7 +73,7 @@ def calculate_lr(
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
else:
|
||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
valid_tokens, total_tokens = 0, 0
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
|
||||
elif "4090" in device_name:
|
||||
return 98 * 1e12 * world_size
|
||||
else:
|
||||
raise NotImplementedError("Device not supported: {}.".format(device_name))
|
||||
raise NotImplementedError(f"Device not supported: {device_name}.")
|
||||
|
||||
|
||||
def calculate_mfu(
|
||||
@ -140,10 +139,10 @@ def calculate_mfu(
|
||||
"bf16": True,
|
||||
}
|
||||
if deepspeed_stage in [2, 3]:
|
||||
args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
|
||||
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
|
||||
|
||||
run_exp(args)
|
||||
with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
|
||||
result = json.load(f)
|
||||
|
||||
if dist.is_initialized():
|
||||
@ -157,7 +156,7 @@ def calculate_mfu(
|
||||
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
|
||||
/ compute_device_flops(world_size)
|
||||
)
|
||||
print("MFU: {:.2f}%".format(mfu_value * 100))
|
||||
print(f"MFU: {mfu_value * 100:.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -100,7 +99,7 @@ def calculate_ppl(
|
||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
@ -125,8 +124,8 @@ def calculate_ppl(
|
||||
with open(save_name, "w", encoding="utf-8") as f:
|
||||
json.dump(perplexities, f, indent=2)
|
||||
|
||||
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
|
||||
print("Perplexities have been saved at {}.".format(save_name))
|
||||
print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
|
||||
print(f"Perplexities have been saved at {save_name}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -61,7 +60,7 @@ def length_cdf(
|
||||
for length, count in length_tuples:
|
||||
count_accu += count
|
||||
prob_accu += count / total_num * 100
|
||||
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
|
||||
print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||
@ -40,7 +39,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def change_name(name: str, old_index: int, new_index: int) -> str:
|
||||
return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
|
||||
return name.replace(f".{old_index:d}.", f".{new_index:d}.")
|
||||
|
||||
|
||||
def block_expansion(
|
||||
@ -76,27 +75,27 @@ def block_expansion(
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if num_layers % num_expand != 0:
|
||||
raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
|
||||
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
|
||||
|
||||
split = num_layers // num_expand
|
||||
layer_cnt = 0
|
||||
output_state_dict = OrderedDict()
|
||||
for i in range(num_layers):
|
||||
for key, value in state_dict.items():
|
||||
if ".{:d}.".format(i) in key:
|
||||
if f".{i:d}." in key:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = value
|
||||
|
||||
print("Add layer {} copied from layer {}".format(layer_cnt, i))
|
||||
print(f"Add layer {layer_cnt} copied from layer {i}")
|
||||
layer_cnt += 1
|
||||
if (i + 1) % split == 0:
|
||||
for key, value in state_dict.items():
|
||||
if ".{:d}.".format(i) in key:
|
||||
if f".{i:d}." in key:
|
||||
if "down_proj" in key or "o_proj" in key:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
|
||||
else:
|
||||
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
|
||||
|
||||
print("Add layer {} expanded from layer {}".format(layer_cnt, i))
|
||||
print(f"Add layer {layer_cnt} expanded from layer {i}")
|
||||
layer_cnt += 1
|
||||
|
||||
for key, value in state_dict.items():
|
||||
@ -113,17 +112,17 @@ def block_expansion(
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
||||
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print(f"model_name_or_path: {output_dir}")
|
||||
print("finetuning_type: freeze")
|
||||
print("freeze_trainable_layers: {}".format(num_expand))
|
||||
print(f"freeze_trainable_layers: {num_expand}")
|
||||
print("use_llama_pro: true")
|
||||
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||
print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
|
||||
def save_config(input_dir: str, output_dir: str):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||
|
||||
|
||||
def llamafy_baichuan2(
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
elif "lm_head" in key:
|
||||
llama2_state_dict[key] = value
|
||||
else:
|
||||
raise KeyError("Unable to process key {}".format(key))
|
||||
raise KeyError(f"Unable to process key {key}")
|
||||
|
||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||
@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
|
||||
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
|
||||
else:
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
return str(torch_dtype).replace("torch.", "")
|
||||
|
||||
|
||||
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
|
||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
||||
@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")
|
||||
|
||||
|
||||
def llamafy_qwen(
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is based on the HuggingFace's PEFT library.
|
||||
@ -70,19 +69,19 @@ def quantize_loftq(
|
||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
|
||||
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
|
||||
print("Adapter weights saved in {}".format(loftq_dir))
|
||||
print(f"Adapter weights saved in {loftq_dir}")
|
||||
|
||||
# Save base model
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("adapter_name_or_path: {}".format(loftq_dir))
|
||||
print(f"model_name_or_path: {output_dir}")
|
||||
print(f"adapter_name_or_path: {loftq_dir}")
|
||||
print("finetuning_type: lora")
|
||||
print("quantization_bit: {}".format(loftq_bits))
|
||||
print(f"quantization_bit: {loftq_bits}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is based on the HuggingFace's PEFT library.
|
||||
@ -54,7 +53,7 @@ def quantize_pissa(
|
||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=lora_target,
|
||||
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
|
||||
init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
|
||||
)
|
||||
|
||||
# Init PiSSA model
|
||||
@ -65,17 +64,17 @@ def quantize_pissa(
|
||||
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
|
||||
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
|
||||
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
|
||||
print("Adapter weights saved in {}".format(pissa_dir))
|
||||
print(f"Adapter weights saved in {pissa_dir}")
|
||||
|
||||
# Save base model
|
||||
base_model: "PreTrainedModel" = peft_model.unload()
|
||||
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
print(f"Model weights saved in {output_dir}")
|
||||
|
||||
print("- Fine-tune this model with:")
|
||||
print("model_name_or_path: {}".format(output_dir))
|
||||
print("adapter_name_or_path: {}".format(pissa_dir))
|
||||
print(f"model_name_or_path: {output_dir}")
|
||||
print(f"adapter_name_or_path: {pissa_dir}")
|
||||
print("finetuning_type: lora")
|
||||
print("pissa_init: false")
|
||||
print("pissa_convert: true")
|
||||
|
65
scripts/test_image.py
Normal file
65
scripts/test_image.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from openai import OpenAI
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
||||
|
||||
|
||||
def main():
|
||||
client = OpenAI(
|
||||
api_key="{}".format(os.environ.get("API_KEY", "0")),
|
||||
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
|
||||
)
|
||||
messages = []
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Output the color and number of each box."},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
result = client.chat.completions.create(messages=messages, model="test")
|
||||
messages.append(result.choices[0].message)
|
||||
print("Round 1:", result.choices[0].message.content)
|
||||
# The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What kind of flower is this?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
result = client.chat.completions.create(messages=messages, model="test")
|
||||
messages.append(result.choices[0].message)
|
||||
print("Round 2:", result.choices[0].message.content)
|
||||
# The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
11
setup.py
11
setup.py
@ -20,7 +20,7 @@ from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def get_version() -> str:
|
||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
||||
(version,) = re.findall(pattern, file_content)
|
||||
@ -28,7 +28,7 @@ def get_version() -> str:
|
||||
|
||||
|
||||
def get_requires() -> List[str]:
|
||||
with open("requirements.txt", "r", encoding="utf-8") as f:
|
||||
with open("requirements.txt", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
||||
return lines
|
||||
@ -54,13 +54,14 @@ extra_require = {
|
||||
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
|
||||
"awq": ["autoawq"],
|
||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||
"vllm": ["vllm>=0.4.3,<=0.6.0"],
|
||||
"vllm": ["vllm>=0.4.3,<=0.6.3"],
|
||||
"galore": ["galore-torch"],
|
||||
"badam": ["badam>=1.2.1"],
|
||||
"adam-mini": ["adam-mini"],
|
||||
"qwen": ["transformers_stream_generator"],
|
||||
"modelscope": ["modelscope"],
|
||||
"dev": ["ruff", "pytest"],
|
||||
"openmind": ["openmind"],
|
||||
"dev": ["pre-commit", "ruff", "pytest"],
|
||||
}
|
||||
|
||||
|
||||
@ -71,7 +72,7 @@ def main():
|
||||
author="hiyouga",
|
||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
||||
description="Easy-to-use LLM fine-tuning framework",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
||||
license="Apache 2.0 License",
|
||||
|
@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
api_host = os.getenv("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.getenv("API_PORT", "8000"))
|
||||
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
||||
|
||||
|
@ -20,17 +20,17 @@ Level:
|
||||
|
||||
Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2,<=4.45.0
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
accelerate>=0.30.1,<=0.34.2
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
datasets>=2.16.0,<=3.0.2
|
||||
accelerate>=0.34.0,<=1.0.1
|
||||
peft>=0.11.1,<=0.12.0
|
||||
trl>=0.8.6,<=0.9.6
|
||||
attention:
|
||||
transformers>=4.42.4 (gemma+fa2)
|
||||
longlora:
|
||||
transformers>=4.41.2,<=4.45.0
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
packing:
|
||||
transformers>=4.41.2,<=4.45.0
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
|
||||
Disable version checking: DISABLE_VERSION_CHECK=1
|
||||
Enable VRAM recording: RECORD_VRAM=1
|
||||
@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
|
||||
Force using torchrun: FORCE_TORCHRUN=1
|
||||
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
|
||||
Use modelscope: USE_MODELSCOPE_HUB=1
|
||||
Use openmind: USE_OPENMIND_HUB=1
|
||||
"""
|
||||
|
||||
from .extras.env import VERSION
|
||||
|
@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem
|
||||
|
||||
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
|
||||
root_path = os.getenv("FASTAPI_ROOT_PATH", "")
|
||||
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
api_key = os.environ.get("API_KEY", None)
|
||||
api_key = os.getenv("API_KEY")
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
|
||||
model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post(
|
||||
@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
def run_api() -> None:
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
api_host = os.environ.get("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.environ.get("API_PORT", "8000"))
|
||||
print("Visit http://localhost:{}/docs for API document.".format(api_port))
|
||||
api_host = os.getenv("API_HOST", "0.0.0.0")
|
||||
api_port = int(os.getenv("API_PORT", "8000"))
|
||||
print(f"Visit http://localhost:{api_port}/docs for API document.")
|
||||
uvicorn.run(app, host=api_host, port=api_port)
|
||||
|
@ -21,7 +21,7 @@ import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||
from .common import dictify, jsonify
|
||||
from .protocol import (
|
||||
@ -57,7 +57,7 @@ if TYPE_CHECKING:
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
ROLE_MAPPING = {
|
||||
Role.USER: DataRole.USER.value,
|
||||
Role.ASSISTANT: DataRole.ASSISTANT.value,
|
||||
@ -69,8 +69,8 @@ ROLE_MAPPING = {
|
||||
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
|
||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
|
||||
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
|
||||
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
|
||||
@ -84,7 +84,7 @@ def _process_request(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
input_messages = []
|
||||
image = None
|
||||
images = []
|
||||
for i, message in enumerate(request.messages):
|
||||
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
@ -111,7 +111,7 @@ def _process_request(
|
||||
else: # web uri
|
||||
image_stream = requests.get(image_url, stream=True).raw
|
||||
|
||||
image = Image.open(image_stream).convert("RGB")
|
||||
images.append(Image.open(image_stream).convert("RGB"))
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
@ -124,7 +124,7 @@ def _process_request(
|
||||
else:
|
||||
tools = None
|
||||
|
||||
return input_messages, system, tools, image
|
||||
return input_messages, system, tools, images or None
|
||||
|
||||
|
||||
def _create_stream_chat_completion_chunk(
|
||||
@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
|
||||
async def create_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> "ChatCompletionResponse":
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, images = _process_request(request)
|
||||
responses = await chat_model.achat(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
images,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
@ -169,7 +169,7 @@ async def create_chat_completion_response(
|
||||
tool_calls = []
|
||||
for tool in result:
|
||||
function = Function(name=tool[0], arguments=tool[1])
|
||||
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
||||
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
|
||||
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||
finish_reason = Finish.TOOL
|
||||
@ -193,8 +193,8 @@ async def create_chat_completion_response(
|
||||
async def create_stream_chat_completion_response(
|
||||
request: "ChatCompletionRequest", chat_model: "ChatModel"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
input_messages, system, tools, image = _process_request(request)
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
input_messages, system, tools, images = _process_request(request)
|
||||
if tools:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
|
||||
|
||||
@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
|
||||
input_messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
images,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
|
||||
async def create_score_evaluation_response(
|
||||
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
|
||||
) -> "ScoreEvaluationResponse":
|
||||
score_id = "scoreval-{}".format(uuid.uuid4().hex)
|
||||
score_id = f"scoreval-{uuid.uuid4().hex}"
|
||||
if len(request.messages) == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
|
@ -66,8 +66,8 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
@ -81,8 +81,8 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
|
@ -53,7 +53,7 @@ class ChatModel:
|
||||
elif model_args.infer_backend == "vllm":
|
||||
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
else:
|
||||
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
||||
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
|
||||
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||
@ -64,15 +64,15 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
task = asyncio.run_coroutine_threadsafe(
|
||||
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
|
||||
self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
|
||||
)
|
||||
return task.result()
|
||||
|
||||
@ -81,28 +81,28 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Asynchronously gets a list of responses of the chat model.
|
||||
"""
|
||||
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
|
||||
return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
|
||||
generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||
@ -115,14 +115,14 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Asynchronously gets the response token-by-token of the chat model.
|
||||
"""
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
|
||||
yield new_token
|
||||
|
||||
def get_scores(
|
||||
|
@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .base_engine import BaseEngine, Response
|
||||
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class HuggingfaceEngine(BaseEngine):
|
||||
@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
|
||||
try:
|
||||
asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
logger.warning("There is no current event loop, creating a new one.")
|
||||
logger.warning_once("There is no current event loop, creating a new one.")
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
|
||||
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||
|
||||
@staticmethod
|
||||
def _process_args(
|
||||
@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||
if image is not None:
|
||||
mm_input_dict.update({"images": [image], "imglens": [1]})
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
if images is not None:
|
||||
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||
if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages):
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||
|
||||
if video is not None:
|
||||
mm_input_dict.update({"videos": [video], "vidlens": [1]})
|
||||
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
|
||||
if videos is not None:
|
||||
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
|
||||
if not any(VIDEO_PLACEHOLDER not in message["content"] for message in messages):
|
||||
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(
|
||||
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
||||
@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
|
||||
logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
|
||||
|
||||
generating_args = generating_args.copy()
|
||||
generating_args.update(
|
||||
@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
|
||||
value = torch.stack(value) # assume they have same sizes
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
value = torch.tensor(value)
|
||||
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
model,
|
||||
tokenizer,
|
||||
processor,
|
||||
template,
|
||||
generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
model,
|
||||
tokenizer,
|
||||
processor,
|
||||
template,
|
||||
generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
if not self.can_generate:
|
||||
@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
|
@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||
from ..model import load_config, load_tokenizer
|
||||
@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VllmEngine(BaseEngine):
|
||||
@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
import vllm.model_executor.models.llava
|
||||
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||
|
||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||
@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
if image is not None:
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
if images is not None:
|
||||
if not any(IMAGE_PLACEHOLDER not in message["content"] for message in messages):
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
if image is not None: # add image features
|
||||
if not isinstance(image, (str, ImageObject)):
|
||||
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
|
||||
if images is not None: # add image features
|
||||
image_data = []
|
||||
for image in images:
|
||||
if not isinstance(image, (str, ImageObject)):
|
||||
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
multi_modal_data = {"image": image}
|
||||
image_data.append(image)
|
||||
|
||||
multi_modal_data = {"image": image_data}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
final_output = None
|
||||
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
||||
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||
async for request_output in generator:
|
||||
final_output = request_output
|
||||
|
||||
@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generated_text = ""
|
||||
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
|
||||
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
|
||||
async for result in generator:
|
||||
delta_text = result.outputs[0].text[len(generated_text) :]
|
||||
generated_text = result.outputs[0].text
|
||||
|
@ -22,8 +22,8 @@ from . import launcher
|
||||
from .api.app import run_api
|
||||
from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .extras import logging
|
||||
from .extras.env import VERSION, print_env
|
||||
from .extras.logging import get_logger
|
||||
from .extras.misc import get_device_count
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
@ -47,7 +47,7 @@ USAGE = (
|
||||
WELCOME = (
|
||||
"-" * 58
|
||||
+ "\n"
|
||||
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
|
||||
+ f"| Welcome to LLaMA Factory, version {VERSION}"
|
||||
+ " " * (21 - len(VERSION))
|
||||
+ "|\n|"
|
||||
+ " " * 56
|
||||
@ -56,7 +56,7 @@ WELCOME = (
|
||||
+ "-" * 58
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@unique
|
||||
@ -86,19 +86,19 @@ def main():
|
||||
elif command == Command.EXPORT:
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
|
||||
if force_torchrun or get_device_count() > 1:
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
|
||||
process = subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||
).format(
|
||||
nnodes=os.environ.get("NNODES", "1"),
|
||||
node_rank=os.environ.get("RANK", "0"),
|
||||
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
|
||||
nnodes=os.getenv("NNODES", "1"),
|
||||
node_rank=os.getenv("NODE_RANK", "0"),
|
||||
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
file_name=launcher.__file__,
|
||||
@ -118,4 +118,4 @@ def main():
|
||||
elif command == Command.HELP:
|
||||
print(USAGE)
|
||||
else:
|
||||
raise NotImplementedError("Unknown command: {}.".format(command))
|
||||
raise NotImplementedError(f"Unknown command: {command}.")
|
||||
|
@ -16,7 +16,7 @@ import os
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
|
||||
|
||||
@ -29,45 +29,51 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(
|
||||
images: Sequence["ImageInput"],
|
||||
images: Union["ImageInput", Sequence["ImageInput"]],
|
||||
dataset_attr: "DatasetAttr",
|
||||
data_args: "DataArguments",
|
||||
) -> Optional[List["ImageInput"]]:
|
||||
r"""
|
||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||
"""
|
||||
if len(images) == 0:
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
elif len(images) == 0:
|
||||
return None
|
||||
else:
|
||||
images = images[:]
|
||||
|
||||
images = images[:]
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for i in range(len(images)):
|
||||
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
|
||||
images[i] = os.path.join(data_args.dataset_dir, images[i])
|
||||
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
|
||||
images[i] = os.path.join(data_args.image_dir, images[i])
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def _convert_videos(
|
||||
videos: Sequence["VideoInput"],
|
||||
videos: Union["VideoInput", Sequence["VideoInput"]],
|
||||
dataset_attr: "DatasetAttr",
|
||||
data_args: "DataArguments",
|
||||
) -> Optional[List["VideoInput"]]:
|
||||
r"""
|
||||
Optionally concatenates video path to dataset dir when loading from local disk.
|
||||
"""
|
||||
if len(videos) == 0:
|
||||
if not isinstance(videos, list):
|
||||
videos = [videos]
|
||||
elif len(videos) == 0:
|
||||
return None
|
||||
else:
|
||||
videos = videos[:]
|
||||
|
||||
videos = videos[:]
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for i in range(len(videos)):
|
||||
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
|
||||
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
|
||||
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
|
||||
videos[i] = os.path.join(data_args.image_dir, videos[i])
|
||||
|
||||
return videos
|
||||
|
||||
@ -161,7 +167,7 @@ def convert_sharegpt(
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
logger.warning_rank0(f"Invalid role tag in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
@ -171,7 +177,7 @@ def convert_sharegpt(
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning("Invalid message count in {}.".format(messages))
|
||||
logger.warning_rank0(f"Invalid message count in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||
@ -192,7 +198,7 @@ def convert_sharegpt(
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
|
||||
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
@ -205,7 +211,7 @@ def convert_sharegpt(
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
logger.warning_rank0("Skipping this abnormal example.")
|
||||
prompt, response = [], []
|
||||
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
features: Dict[str, "torch.Tensor"] = super().__call__(features)
|
||||
features.update(mm_inputs)
|
||||
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
|
||||
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
for key in ("chosen", "rejected"):
|
||||
for feature in features:
|
||||
target_feature = {
|
||||
"input_ids": feature["{}_input_ids".format(key)],
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
"input_ids": feature[f"{key}_input_ids"],
|
||||
"attention_mask": feature[f"{key}_attention_mask"],
|
||||
"labels": feature[f"{key}_labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
|
||||
|
||||
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
@ -56,12 +56,12 @@ def merge_dataset(
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
|
||||
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
@ -70,7 +70,7 @@ def merge_dataset(
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
|
||||
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
|
||||
|
||||
|
||||
def split_dataset(
|
||||
|
@ -83,14 +83,14 @@ class StringFormatter(Formatter):
|
||||
if isinstance(slot, str):
|
||||
for name, value in kwargs.items():
|
||||
if not isinstance(value, str):
|
||||
raise RuntimeError("Expected a string, got {}".format(value))
|
||||
raise RuntimeError(f"Expected a string, got {value}")
|
||||
|
||||
slot = slot.replace("{{" + name + "}}", value, 1)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
|
||||
return elements
|
||||
|
||||
@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
|
||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
functions = []
|
||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||
|
||||
elements = []
|
||||
for name, arguments in functions:
|
||||
@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
|
||||
return elements
|
||||
|
||||
@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
|
||||
tools = json.loads(content)
|
||||
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
except json.JSONDecodeError:
|
||||
return [""]
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
|
@ -20,8 +20,8 @@ import numpy as np
|
||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import has_tokenized_data
|
||||
from .aligner import align_dataset
|
||||
from .data_utils import merge_dataset, split_dataset
|
||||
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
from .template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _load_single_dataset(
|
||||
@ -51,9 +51,9 @@ def _load_single_dataset(
|
||||
r"""
|
||||
Loads a single dataset and aligns it to the standard format.
|
||||
"""
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
logger.info_rank0(f"Loading dataset {dataset_attr}...")
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_name = dataset_attr.subset
|
||||
data_dir = dataset_attr.folder
|
||||
@ -69,25 +69,24 @@ def _load_single_dataset(
|
||||
if os.path.isdir(local_path): # is directory
|
||||
for file_name in os.listdir(local_path):
|
||||
data_files.append(os.path.join(local_path, file_name))
|
||||
if data_path is None:
|
||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
||||
raise ValueError("File types should be identical.")
|
||||
elif os.path.isfile(local_path): # is file
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File {} not found.".format(local_path))
|
||||
raise ValueError(f"File {local_path} not found.")
|
||||
|
||||
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
|
||||
if data_path is None:
|
||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||
|
||||
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
|
||||
raise ValueError("File types should be identical.")
|
||||
else:
|
||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
||||
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
from modelscope import MsDataset
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
from modelscope import MsDataset # type: ignore
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||
|
||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||
dataset = MsDataset.load(
|
||||
@ -98,10 +97,27 @@ def _load_single_dataset(
|
||||
split=dataset_attr.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
use_streaming=data_args.streaming,
|
||||
)
|
||||
if isinstance(dataset, MsDataset):
|
||||
dataset = dataset.to_hf_dataset()
|
||||
|
||||
elif dataset_attr.load_from == "om_hub":
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
from openmind import OmDataset # type: ignore
|
||||
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
|
||||
|
||||
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
|
||||
dataset = OmDataset.load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=dataset_attr.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.om_hub_token,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
else:
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
@ -111,13 +127,10 @@ def _load_single_dataset(
|
||||
split=dataset_attr.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
streaming=data_args.streaming,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||
@ -128,7 +141,7 @@ def _load_single_dataset(
|
||||
|
||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||
dataset = dataset.select(indexes)
|
||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
||||
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
max_samples = min(data_args.max_samples, len(dataset))
|
||||
@ -224,9 +237,9 @@ def get_dataset(
|
||||
# Load tokenized dataset
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
|
||||
|
||||
dataset_module: Dict[str, "Dataset"] = {}
|
||||
if "train" in dataset_dict:
|
||||
@ -277,8 +290,8 @@ def get_dataset(
|
||||
if data_args.tokenized_path is not None:
|
||||
if training_args.should_save:
|
||||
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
||||
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
|
||||
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
@ -4,6 +4,7 @@ from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers.image_utils import get_image_size, to_numpy_array
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
@ -110,7 +111,7 @@ class BasePlugin:
|
||||
image = Image.open(image["path"])
|
||||
|
||||
if not isinstance(image, ImageObject):
|
||||
raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
|
||||
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
|
||||
|
||||
results.append(self._preprocess_image(image, **kwargs))
|
||||
|
||||
@ -157,6 +158,7 @@ class BasePlugin:
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
|
||||
input_dict = {"images": None} # default key
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
@ -174,10 +176,16 @@ class BasePlugin:
|
||||
)
|
||||
input_dict["videos"] = videos
|
||||
|
||||
if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
|
||||
return image_processor(**input_dict, return_tensors="pt")
|
||||
else:
|
||||
return {}
|
||||
mm_inputs = {}
|
||||
if image_processor != video_processor:
|
||||
if input_dict.get("images") is not None:
|
||||
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
|
||||
if input_dict.get("videos") is not None:
|
||||
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
|
||||
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
|
||||
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def process_messages(
|
||||
self,
|
||||
@ -218,6 +226,14 @@ class BasePlugin:
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
r"""
|
||||
Builds batched multimodal inputs for VLMs.
|
||||
|
||||
Arguments:
|
||||
images: a list of image inputs, shape (num_images,)
|
||||
videos: a list of video inputs, shape (num_videos,)
|
||||
imglens: number of images in each sample, shape (batch_size,)
|
||||
vidlens: number of videos in each sample, shape (batch_size,)
|
||||
seqlens: number of tokens in each sample, shape (batch_size,)
|
||||
processor: a processor for pre-processing images and videos
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
return {}
|
||||
@ -245,7 +261,123 @@ class LlavaPlugin(BasePlugin):
|
||||
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
class LlavaNextPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
if "image_sizes" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"])
|
||||
if "pixel_values" in mm_inputs:
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while self.image_token in content:
|
||||
image_size = next(image_sizes)
|
||||
orig_height, orig_width = image_size
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if processor.vision_feature_select_strategy == "default":
|
||||
image_seqlen -= 1
|
||||
num_image_tokens += 1
|
||||
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
res = self._get_mm_inputs(images, videos, processor)
|
||||
return res
|
||||
|
||||
|
||||
class LlavaNextVideoPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
if "pixel_values" in mm_inputs:
|
||||
image_sizes = iter(mm_inputs["image_sizes"])
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
|
||||
while self.image_token in content:
|
||||
image_size = next(image_sizes)
|
||||
orig_height, orig_width = image_size
|
||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if processor.vision_feature_select_strategy == "default":
|
||||
image_seqlen -= 1
|
||||
num_image_tokens += 1
|
||||
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
|
||||
|
||||
message["content"] = content.replace("{{image}}", self.image_token)
|
||||
|
||||
if "pixel_values_videos" in mm_inputs:
|
||||
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||
height, width = get_image_size(pixel_values_video[0])
|
||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
||||
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while self.video_token in content:
|
||||
num_video_tokens += 1
|
||||
content = content.replace(self.video_token, "{{video}}", 1)
|
||||
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -284,7 +416,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
message["content"] = content.replace("{{image}}", "")
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -324,6 +456,68 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class PixtralPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
patch_size = getattr(processor, "patch_size")
|
||||
image_token = getattr(processor, "image_token")
|
||||
image_break_token = getattr(processor, "image_break_token")
|
||||
image_end_token = getattr(processor, "image_end_token")
|
||||
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
image_input_sizes = mm_inputs.get("image_sizes", None)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if image_input_sizes is None:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
image_size = image_input_sizes[0][num_image_tokens]
|
||||
height, width = image_size
|
||||
num_height_tokens = height // patch_size
|
||||
num_width_tokens = width // patch_size
|
||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
||||
replace_tokens[-1] = image_end_token
|
||||
replace_str = "".join(replace_tokens)
|
||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
if mm_inputs.get("pixel_values"):
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
|
||||
|
||||
mm_inputs.pop("image_sizes", None)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
@override
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
@ -369,7 +563,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(image_grid_thw):
|
||||
raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
@ -382,7 +576,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(video_grid_thw):
|
||||
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER,
|
||||
@ -396,10 +590,73 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
|
||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
return self._get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
class VideoLlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
num_frames = 0
|
||||
exist_images = "pixel_values_images" in mm_inputs
|
||||
exist_videos = "pixel_values_videos" in mm_inputs
|
||||
if exist_videos or exist_images:
|
||||
if exist_images:
|
||||
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
|
||||
num_frames = 1
|
||||
if exist_videos:
|
||||
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||
height, width = get_image_size(pixel_values_video[0])
|
||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
|
||||
video_seqlen = image_seqlen * num_frames
|
||||
if processor.vision_feature_select_strategy == "default":
|
||||
image_seqlen -= 1
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while self.image_token in content:
|
||||
num_image_tokens += 1
|
||||
content = content.replace(self.image_token, "{{image}}", 1)
|
||||
while self.video_token in content:
|
||||
num_video_tokens += 1
|
||||
content = content.replace(self.video_token, "{{video}}", 1)
|
||||
|
||||
content = content.replace("{{image}}", self.image_token * image_seqlen)
|
||||
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
@ -420,8 +677,12 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"llava": LlavaPlugin,
|
||||
"llava_next": LlavaNextPlugin,
|
||||
"llava_next_video": LlavaNextVideoPlugin,
|
||||
"paligemma": PaliGemmaPlugin,
|
||||
"pixtral": PixtralPlugin,
|
||||
"qwen2_vl": Qwen2vlPlugin,
|
||||
"video_llava": VideoLlavaPlugin,
|
||||
}
|
||||
|
||||
|
||||
@ -432,6 +693,6 @@ def get_mm_plugin(
|
||||
) -> "BasePlugin":
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class is None:
|
||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||
raise ValueError(f"Multimodal plugin `{name}` not found.")
|
||||
|
||||
return plugin_class(image_token, video_token)
|
||||
|
@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from ..extras.constants import DATA_CONFIG
|
||||
from ..extras.misc import use_modelscope
|
||||
from ..extras.misc import use_modelscope, use_openmind
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -30,7 +30,7 @@ class DatasetAttr:
|
||||
"""
|
||||
|
||||
# basic configs
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
|
||||
dataset_name: str
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
ranking: bool = False
|
||||
@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
|
||||
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
with open(config_path) as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as err:
|
||||
if len(dataset_names) != 0:
|
||||
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
||||
raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
|
||||
|
||||
dataset_info = None
|
||||
|
||||
dataset_list: List["DatasetAttr"] = []
|
||||
for name in dataset_names:
|
||||
if dataset_info is None: # dataset_dir is ONLINE
|
||||
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
||||
if use_modelscope():
|
||||
load_from = "ms_hub"
|
||||
elif use_openmind():
|
||||
load_from = "om_hub"
|
||||
else:
|
||||
load_from = "hf_hub"
|
||||
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||
dataset_list.append(dataset_attr)
|
||||
continue
|
||||
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
||||
raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
|
||||
|
||||
has_hf_url = "hf_hub_url" in dataset_info[name]
|
||||
has_ms_url = "ms_hub_url" in dataset_info[name]
|
||||
has_om_url = "om_hub_url" in dataset_info[name]
|
||||
|
||||
if has_hf_url or has_ms_url:
|
||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
||||
if has_hf_url or has_ms_url or has_om_url:
|
||||
if has_ms_url and (use_modelscope() or not has_hf_url):
|
||||
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||
elif has_om_url and (use_openmind() or not has_hf_url):
|
||||
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
elif "script_url" in dataset_info[name]:
|
||||
|
@ -15,8 +15,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_feedback_example(
|
||||
@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
||||
@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
logger.warning_rank0("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
|
@ -15,8 +15,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_pairwise_example(
|
||||
@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||
@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
|
||||
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
|
||||
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
|
||||
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
|
||||
print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
|
||||
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
|
||||
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
|
||||
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
|
||||
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
|
||||
print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
|
||||
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
|
||||
|
@ -15,8 +15,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import greedy_knapsack, infer_seqlen
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_supervised_example(
|
||||
@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
|
||||
length2indexes = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
|
||||
)
|
||||
length = len(input_ids)
|
||||
if length > data_args.cutoff_len:
|
||||
logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
|
||||
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
|
||||
else:
|
||||
lengths.append(length)
|
||||
length2indexes[length].append(valid_num)
|
||||
@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
|
||||
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
|
||||
|
@ -15,7 +15,7 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras import logging
|
||||
from ..data_utils import Role
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_unsupervised_example(
|
||||
@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_unsupervised_example(
|
||||
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from transformers.utils.versions import require_version
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .mm_plugin import get_mm_plugin
|
||||
@ -32,7 +32,7 @@ if TYPE_CHECKING:
|
||||
from .mm_plugin import BasePlugin
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -49,6 +49,7 @@ class Template:
|
||||
stop_words: List[str]
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
replace_jinja_template: bool
|
||||
mm_plugin: "BasePlugin"
|
||||
|
||||
def encode_oneturn(
|
||||
@ -146,7 +147,7 @@ class Template:
|
||||
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
||||
token_ids += [tokenizer.eos_token_id]
|
||||
else:
|
||||
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
||||
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
|
||||
|
||||
return token_ids
|
||||
|
||||
@ -214,6 +215,7 @@ def _register_template(
|
||||
stop_words: Sequence[str] = [],
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
replace_jinja_template: bool = True,
|
||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
||||
) -> None:
|
||||
r"""
|
||||
@ -263,6 +265,7 @@ def _register_template(
|
||||
stop_words=stop_words,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
replace_jinja_template=replace_jinja_template,
|
||||
mm_plugin=mm_plugin,
|
||||
)
|
||||
|
||||
@ -272,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
|
||||
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
||||
|
||||
if is_added:
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
|
||||
else:
|
||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
|
||||
|
||||
if num_added_tokens > 0:
|
||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
|
||||
|
||||
def _jinja_escape(content: str) -> str:
|
||||
@ -353,24 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
r"""
|
||||
Gets chat template and fixes the tokenizer.
|
||||
"""
|
||||
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
|
||||
require_version(
|
||||
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||
)
|
||||
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
|
||||
|
||||
if data_args.template is None:
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
else:
|
||||
template = TEMPLATES.get(data_args.template, None)
|
||||
if template is None:
|
||||
raise ValueError("Template {} does not exist.".format(data_args.template))
|
||||
raise ValueError(f"Template {data_args.template} does not exist.")
|
||||
|
||||
if template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
|
||||
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
if data_args.tool_format is not None:
|
||||
logger.info("Using tool format: {}.".format(data_args.tool_format))
|
||||
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
|
||||
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
||||
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
|
||||
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
||||
@ -388,20 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
|
||||
|
||||
if stop_words:
|
||||
num_added_tokens = tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
||||
)
|
||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
|
||||
if num_added_tokens > 0:
|
||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
|
||||
try:
|
||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
||||
except ValueError:
|
||||
logger.info("Cannot add this chat template to tokenizer.")
|
||||
if tokenizer.chat_template is None or template.replace_jinja_template:
|
||||
try:
|
||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
||||
except ValueError as e:
|
||||
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
|
||||
|
||||
return template
|
||||
|
||||
@ -640,6 +641,14 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="exaone",
|
||||
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
|
||||
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="falcon",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
||||
@ -664,6 +673,7 @@ _register_template(
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
replace_jinja_template=False,
|
||||
)
|
||||
|
||||
|
||||
@ -681,6 +691,14 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="index",
|
||||
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
|
||||
format_system=StringFormatter(slots=["<unk>{{content}}"]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="intern",
|
||||
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
||||
@ -740,6 +758,7 @@ _register_template(
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
)
|
||||
|
||||
|
||||
@ -754,6 +773,107 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next_llama3",
|
||||
format_user=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
(
|
||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
]
|
||||
),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next_mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next_qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next_yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next_video",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next_video_mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava_next_video_yi",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
@ -831,6 +951,14 @@ _register_template(
|
||||
replace_eos=True,
|
||||
)
|
||||
|
||||
_register_template(
|
||||
name="pixtral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="qwen",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
@ -840,6 +968,7 @@ _register_template(
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
)
|
||||
|
||||
|
||||
@ -852,6 +981,7 @@ _register_template(
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
replace_jinja_template=False,
|
||||
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
)
|
||||
|
||||
@ -907,6 +1037,17 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="video_llava",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="xuanyuan",
|
||||
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
||||
|
@ -177,6 +177,6 @@ TOOLS = {
|
||||
def get_tool_utils(name: str) -> "ToolUtils":
|
||||
tool_utils = TOOLS.get(name, None)
|
||||
if tool_utils is None:
|
||||
raise ValueError("Tool utils `{}` not found.".format(name))
|
||||
raise ValueError(f"Tool utils `{name}` not found.")
|
||||
|
||||
return tool_utils
|
||||
|
@ -87,7 +87,7 @@ class Evaluator:
|
||||
token=self.model_args.hf_hub_token,
|
||||
)
|
||||
|
||||
with open(mapping, "r", encoding="utf-8") as f:
|
||||
with open(mapping, encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
|
||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||
@ -139,7 +139,7 @@ class Evaluator:
|
||||
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
||||
score_info = "\n".join(
|
||||
[
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
|
||||
for category_name, category_correct in category_corrects.items()
|
||||
if len(category_correct)
|
||||
]
|
||||
|
@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
|
||||
|
||||
def get_eval_template(name: str) -> "EvalTemplate":
|
||||
eval_template = eval_templates.get(name, None)
|
||||
assert eval_template is not None, "Template {} does not exist.".format(name)
|
||||
assert eval_template is not None, f"Template {name} does not exist."
|
||||
return eval_template
|
||||
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
@ -47,7 +48,7 @@ FILEEXT2TYPE = {
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
IMAGE_PLACEHOLDER = "<image>"
|
||||
IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
|
||||
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||
|
||||
VIDEO_PLACEHOLDER = "<video>"
|
||||
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
|
||||
|
||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||
|
||||
@ -107,6 +108,7 @@ VISION_MODELS = set()
|
||||
class DownloadSource(str, Enum):
|
||||
DEFAULT = "hf"
|
||||
MODELSCOPE = "ms"
|
||||
OPENMIND = "om"
|
||||
|
||||
|
||||
def register_model_group(
|
||||
@ -114,17 +116,12 @@ def register_model_group(
|
||||
template: Optional[str] = None,
|
||||
vision: bool = False,
|
||||
) -> None:
|
||||
prefix = None
|
||||
for name, path in models.items():
|
||||
if prefix is None:
|
||||
prefix = name.split("-")[0]
|
||||
else:
|
||||
assert prefix == name.split("-")[0], "prefix should be identical."
|
||||
SUPPORTED_MODELS[name] = path
|
||||
if template is not None:
|
||||
DEFAULT_TEMPLATE[prefix] = template
|
||||
if vision:
|
||||
VISION_MODELS.add(prefix)
|
||||
if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
|
||||
DEFAULT_TEMPLATE[name] = template
|
||||
if vision:
|
||||
VISION_MODELS.add(name)
|
||||
|
||||
|
||||
register_model_group(
|
||||
@ -168,14 +165,17 @@ register_model_group(
|
||||
"Baichuan2-13B-Base": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
|
||||
},
|
||||
"Baichuan2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
|
||||
},
|
||||
"Baichuan2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
||||
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
|
||||
},
|
||||
},
|
||||
template="baichuan2",
|
||||
@ -274,27 +274,27 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChineseLLaMA2-1.3B": {
|
||||
"Chinese-Llama-2-1.3B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
|
||||
},
|
||||
"ChineseLLaMA2-7B": {
|
||||
"Chinese-Llama-2-7B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
|
||||
},
|
||||
"ChineseLLaMA2-13B": {
|
||||
"Chinese-Llama-2-13B": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
|
||||
},
|
||||
"ChineseLLaMA2-1.3B-Chat": {
|
||||
"Chinese-Alpaca-2-1.3B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
|
||||
},
|
||||
"ChineseLLaMA2-7B-Chat": {
|
||||
"Chinese-Alpaca-2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
|
||||
},
|
||||
"ChineseLLaMA2-13B-Chat": {
|
||||
"Chinese-Alpaca-2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
|
||||
},
|
||||
@ -450,25 +450,25 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"DeepSeekCoder-6.7B-Base": {
|
||||
"DeepSeek-Coder-6.7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
||||
},
|
||||
"DeepSeekCoder-7B-Base": {
|
||||
"DeepSeek-Coder-7B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
|
||||
},
|
||||
"DeepSeekCoder-33B-Base": {
|
||||
"DeepSeek-Coder-33B-Base": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
||||
},
|
||||
"DeepSeekCoder-6.7B-Instruct": {
|
||||
"DeepSeek-Coder-6.7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
||||
},
|
||||
"DeepSeekCoder-7B-Instruct": {
|
||||
"DeepSeek-Coder-7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
|
||||
},
|
||||
"DeepSeekCoder-33B-Instruct": {
|
||||
"DeepSeek-Coder-33B-Instruct": {
|
||||
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
||||
},
|
||||
@ -477,6 +477,16 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"EXAONE-3.0-7.8B-Instruct": {
|
||||
DownloadSource.DEFAULT: "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
|
||||
},
|
||||
},
|
||||
template="exaone",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": {
|
||||
@ -550,10 +560,12 @@ register_model_group(
|
||||
"Gemma-2-2B-Instruct": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2-2b-it",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
|
||||
},
|
||||
"Gemma-2-9B-Instruct": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
|
||||
},
|
||||
"Gemma-2-27B-Instruct": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
|
||||
@ -573,6 +585,7 @@ register_model_group(
|
||||
"GLM-4-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
|
||||
},
|
||||
"GLM-4-9B-1M-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
|
||||
@ -583,6 +596,33 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Index-1.9B-Chat": {
|
||||
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
|
||||
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
|
||||
},
|
||||
"Index-1.9B-Character-Chat": {
|
||||
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
|
||||
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
|
||||
},
|
||||
"Index-1.9B-Base": {
|
||||
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
|
||||
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
|
||||
},
|
||||
"Index-1.9B-Base-Pure": {
|
||||
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
|
||||
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
|
||||
},
|
||||
"Index-1.9B-Chat-32K": {
|
||||
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
|
||||
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
|
||||
},
|
||||
},
|
||||
template="index",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": {
|
||||
@ -624,16 +664,10 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
||||
},
|
||||
},
|
||||
template="intern2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM2.5-1.8B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
|
||||
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
|
||||
},
|
||||
"InternLM2.5-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
|
||||
@ -642,22 +676,27 @@ register_model_group(
|
||||
"InternLM2.5-20B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
|
||||
DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
|
||||
},
|
||||
"InternLM2.5-1.8B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
|
||||
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
|
||||
},
|
||||
"InternLM2.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
|
||||
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
|
||||
},
|
||||
"InternLM2.5-7B-1M-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
|
||||
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
|
||||
},
|
||||
"InternLM2.5-20B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
|
||||
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
|
||||
},
|
||||
},
|
||||
template="intern2",
|
||||
@ -686,19 +725,19 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA-7B": {
|
||||
"Llama-7B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
|
||||
},
|
||||
"LLaMA-13B": {
|
||||
"Llama-13B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
||||
},
|
||||
"LLaMA-30B": {
|
||||
"Llama-30B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
|
||||
},
|
||||
"LLaMA-65B": {
|
||||
"Llama-65B": {
|
||||
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
||||
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
|
||||
},
|
||||
@ -708,27 +747,27 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA2-7B": {
|
||||
"Llama-2-7B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
|
||||
},
|
||||
"LLaMA2-13B": {
|
||||
"Llama-2-13B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
|
||||
},
|
||||
"LLaMA2-70B": {
|
||||
"Llama-2-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
|
||||
},
|
||||
"LLaMA2-7B-Chat": {
|
||||
"Llama-2-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
|
||||
},
|
||||
"LLaMA2-13B-Chat": {
|
||||
"Llama-2-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
|
||||
},
|
||||
"LLaMA2-70B-Chat": {
|
||||
"Llama-2-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
||||
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
|
||||
},
|
||||
@ -739,60 +778,78 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA3-8B": {
|
||||
"Llama-3-8B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
|
||||
},
|
||||
"LLaMA3-70B": {
|
||||
"Llama-3-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
|
||||
},
|
||||
"LLaMA3-8B-Instruct": {
|
||||
"Llama-3-8B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
|
||||
},
|
||||
"LLaMA3-70B-Instruct": {
|
||||
"Llama-3-70B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
|
||||
},
|
||||
"LLaMA3-8B-Chinese-Chat": {
|
||||
"Llama-3-8B-Chinese-Chat": {
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
|
||||
},
|
||||
"LLaMA3-70B-Chinese-Chat": {
|
||||
"Llama-3-70B-Chinese-Chat": {
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
|
||||
},
|
||||
},
|
||||
template="llama3",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA3.1-8B": {
|
||||
"Llama-3.1-8B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
|
||||
},
|
||||
"LLaMA3.1-70B": {
|
||||
"Llama-3.1-70B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
|
||||
},
|
||||
"LLaMA3.1-405B": {
|
||||
"Llama-3.1-405B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
|
||||
},
|
||||
"LLaMA3.1-8B-Instruct": {
|
||||
"Llama-3.1-8B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
|
||||
},
|
||||
"LLaMA3.1-70B-Instruct": {
|
||||
"Llama-3.1-70B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
|
||||
},
|
||||
"LLaMA3.1-405B-Instruct": {
|
||||
"Llama-3.1-405B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
|
||||
},
|
||||
"Llama-3.1-8B-Chinese-Chat": {
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-8B-Chinese-Chat",
|
||||
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-8B-Chinese-Chat",
|
||||
},
|
||||
"Llama-3.1-70B-Chinese-Chat": {
|
||||
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-70B-Chinese-Chat",
|
||||
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-70B-Chinese-Chat",
|
||||
},
|
||||
"Llama-3.2-1B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B",
|
||||
},
|
||||
"Llama-3.2-3B": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B",
|
||||
},
|
||||
"Llama-3.2-1B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B-Instruct",
|
||||
},
|
||||
"Llama-3.2-3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
|
||||
},
|
||||
},
|
||||
template="llama3",
|
||||
)
|
||||
@ -800,11 +857,13 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA1.5-7B-Chat": {
|
||||
"LLaVA-1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/llava-1.5-7b-hf",
|
||||
},
|
||||
"LLaVA1.5-13B-Chat": {
|
||||
"LLaVA-1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/llava-1.5-13b-hf",
|
||||
},
|
||||
},
|
||||
template="llava",
|
||||
@ -812,6 +871,117 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-7b-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-7b-hf",
|
||||
},
|
||||
"LLaVA-NeXT-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-13b-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-13b-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-Mistral-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/llava-v1.6-mistral-7b-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next_mistral",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-Llama3-8B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llama3-llava-next-8b-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/llama3-llava-next-8b-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next_llama3",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-34b-hf",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/llava-v1.6-34b-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next_yi",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-72B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-next-72b-hf",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-72b-hf",
|
||||
},
|
||||
"LLaVA-NeXT-110B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-next-110b-hf",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-110b-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next_qwen",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-Video-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-hf",
|
||||
},
|
||||
"LLaVA-NeXT-Video-7B-DPO-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-DPO-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-DPO-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next_video",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-Video-7B-32k-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-32K-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-32K-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next_video_mistral",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-NeXT-Video-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-hf",
|
||||
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-34B-hf",
|
||||
},
|
||||
"LLaVA-NeXT-Video-34B-DPO-Chat": {
|
||||
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-DPO-hf",
|
||||
},
|
||||
},
|
||||
template="llava_next_video_yi",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-2B-SFT-Chat": {
|
||||
@ -832,6 +1002,7 @@ register_model_group(
|
||||
"MiniCPM3-4B-Chat": {
|
||||
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
|
||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
|
||||
},
|
||||
},
|
||||
template="cpm3",
|
||||
@ -1005,27 +1176,27 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi3-4B-4k-Instruct": {
|
||||
"Phi-3-4B-4k-Instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
|
||||
},
|
||||
"Phi3-4B-128k-Instruct": {
|
||||
"Phi-3-4B-128k-Instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
|
||||
},
|
||||
"Phi3-7B-8k-Instruct": {
|
||||
"Phi-3-7B-8k-Instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
|
||||
},
|
||||
"Phi3-7B-128k-Instruct": {
|
||||
"Phi-3-7B-128k-Instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
|
||||
},
|
||||
"Phi3-14B-8k-Instruct": {
|
||||
"Phi-3-14B-8k-Instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
|
||||
},
|
||||
"Phi3-14B-128k-Instruct": {
|
||||
"Phi-3-14B-128k-Instruct": {
|
||||
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
||||
},
|
||||
@ -1034,6 +1205,18 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Pixtral-12B-Chat": {
|
||||
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
|
||||
}
|
||||
},
|
||||
template="pixtral",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-1.8B": {
|
||||
@ -1068,35 +1251,35 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
|
||||
},
|
||||
"Qwen-1.8B-int8-Chat": {
|
||||
"Qwen-1.8B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
|
||||
},
|
||||
"Qwen-1.8B-int4-Chat": {
|
||||
"Qwen-1.8B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
|
||||
},
|
||||
"Qwen-7B-int8-Chat": {
|
||||
"Qwen-7B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
|
||||
},
|
||||
"Qwen-7B-int4-Chat": {
|
||||
"Qwen-7B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
|
||||
},
|
||||
"Qwen-14B-int8-Chat": {
|
||||
"Qwen-14B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
|
||||
},
|
||||
"Qwen-14B-int4-Chat": {
|
||||
"Qwen-14B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
|
||||
},
|
||||
"Qwen-72B-int8-Chat": {
|
||||
"Qwen-72B-Chat-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
|
||||
},
|
||||
"Qwen-72B-int4-Chat": {
|
||||
"Qwen-72B-Chat-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
||||
},
|
||||
@ -1179,75 +1362,75 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
|
||||
},
|
||||
"Qwen1.5-0.5B-int8-Chat": {
|
||||
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-0.5B-int4-Chat": {
|
||||
"Qwen1.5-0.5B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-1.8B-int8-Chat": {
|
||||
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-1.8B-int4-Chat": {
|
||||
"Qwen1.5-1.8B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-4B-int8-Chat": {
|
||||
"Qwen1.5-4B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-4B-int4-Chat": {
|
||||
"Qwen1.5-4B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-7B-int8-Chat": {
|
||||
"Qwen1.5-7B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-7B-int4-Chat": {
|
||||
"Qwen1.5-7B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-14B-int8-Chat": {
|
||||
"Qwen1.5-14B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-14B-int4-Chat": {
|
||||
"Qwen1.5-14B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-32B-int4-Chat": {
|
||||
"Qwen1.5-32B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-72B-int8-Chat": {
|
||||
"Qwen1.5-72B-Chat-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
||||
},
|
||||
"Qwen1.5-72B-int4-Chat": {
|
||||
"Qwen1.5-72B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-110B-int4-Chat": {
|
||||
"Qwen1.5-110B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
|
||||
},
|
||||
"Qwen1.5-MoE-A2.7B-int4-Chat": {
|
||||
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
||||
},
|
||||
"Qwen1.5-Code-7B": {
|
||||
"CodeQwen1.5-7B": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
|
||||
},
|
||||
"Qwen1.5-Code-7B-Chat": {
|
||||
"CodeQwen1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
|
||||
},
|
||||
"Qwen1.5-Code-7B-int4-Chat": {
|
||||
"CodeQwen1.5-7B-Chat-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
|
||||
},
|
||||
@ -1281,14 +1464,17 @@ register_model_group(
|
||||
"Qwen2-0.5B-Instruct": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
|
||||
},
|
||||
"Qwen2-1.5B-Instruct": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
|
||||
},
|
||||
"Qwen2-7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
|
||||
},
|
||||
"Qwen2-72B-Instruct": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
|
||||
@ -1568,51 +1754,53 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen2VL-2B-Instruct": {
|
||||
"Qwen2-VL-2B-Instruct": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
|
||||
},
|
||||
"Qwen2VL-7B-Instruct": {
|
||||
"Qwen2-VL-7B-Instruct": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
|
||||
},
|
||||
"Qwen2VL-72B-Instruct": {
|
||||
"Qwen2-VL-72B-Instruct": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
|
||||
},
|
||||
"Qwen2VL-2B-Instruct-GPTQ-Int8": {
|
||||
"Qwen2-VL-2B-Instruct-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2VL-2B-Instruct-GPTQ-Int4": {
|
||||
"Qwen2-VL-2B-Instruct-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
|
||||
},
|
||||
"Qwen2VL-2B-Instruct-AWQ": {
|
||||
"Qwen2-VL-2B-Instruct-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
|
||||
},
|
||||
"Qwen2VL-7B-Instruct-GPTQ-Int8": {
|
||||
"Qwen2-VL-7B-Instruct-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2VL-7B-Instruct-GPTQ-Int4": {
|
||||
"Qwen2-VL-7B-Instruct-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
|
||||
},
|
||||
"Qwen2VL-7B-Instruct-AWQ": {
|
||||
"Qwen2-VL-7B-Instruct-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
|
||||
},
|
||||
"Qwen2VL-72B-Instruct-GPTQ-Int8": {
|
||||
"Qwen2-VL-72B-Instruct-GPTQ-Int8": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2VL-72B-Instruct-GPTQ-Int4": {
|
||||
"Qwen2-VL-72B-Instruct-GPTQ-Int4": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
|
||||
},
|
||||
"Qwen2VL-72B-Instruct-AWQ": {
|
||||
"Qwen2-VL-72B-Instruct-AWQ": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
|
||||
},
|
||||
@ -1673,10 +1861,12 @@ register_model_group(
|
||||
"TeleChat-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
||||
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
|
||||
},
|
||||
"TeleChat-12B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
|
||||
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
|
||||
},
|
||||
"TeleChat-12B-v2-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
||||
@ -1689,11 +1879,11 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Vicuna1.5-7B-Chat": {
|
||||
"Vicuna-v1.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
|
||||
},
|
||||
"Vicuna1.5-13B-Chat": {
|
||||
"Vicuna-v1.5-13B-Chat": {
|
||||
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
||||
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
|
||||
},
|
||||
@ -1702,6 +1892,17 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Video-LLaVA-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "LanguageBind/Video-LLaVA-7B-hf",
|
||||
},
|
||||
},
|
||||
template="video_llava",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XuanYuan-6B": {
|
||||
@ -1712,7 +1913,7 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
|
||||
},
|
||||
"XuanYuan-2-70B": {
|
||||
"XuanYuan2-70B": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
|
||||
},
|
||||
@ -1724,31 +1925,31 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
||||
},
|
||||
"XuanYuan-2-70B-Chat": {
|
||||
"XuanYuan2-70B-Chat": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
||||
},
|
||||
"XuanYuan-6B-int8-Chat": {
|
||||
"XuanYuan-6B-Chat-8bit": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-6B-int4-Chat": {
|
||||
"XuanYuan-6B-Chat-4bit": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
||||
},
|
||||
"XuanYuan-70B-int8-Chat": {
|
||||
"XuanYuan-70B-Chat-8bit": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-70B-int4-Chat": {
|
||||
"XuanYuan-70B-Chat-4bit": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
||||
},
|
||||
"XuanYuan-2-70B-int8-Chat": {
|
||||
"XuanYuan2-70B-Chat-8bit": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
||||
},
|
||||
"XuanYuan-2-70B-int4-Chat": {
|
||||
"XuanYuan2-70B-Chat-4bit": {
|
||||
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
||||
},
|
||||
@ -1853,19 +2054,19 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
|
||||
},
|
||||
"Yi-6B-int8-Chat": {
|
||||
"Yi-6B-Chat-8bits": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
||||
},
|
||||
"Yi-6B-int4-Chat": {
|
||||
"Yi-6B-Chat-4bits": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
|
||||
},
|
||||
"Yi-34B-int8-Chat": {
|
||||
"Yi-34B-Chat-8bits": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
||||
},
|
||||
"Yi-34B-int4-Chat": {
|
||||
"Yi-34B-Chat-4bits": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
||||
},
|
||||
@ -1884,6 +2085,7 @@ register_model_group(
|
||||
"Yi-1.5-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
|
||||
DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
|
||||
},
|
||||
"Yi-1.5-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
|
||||
@ -1916,10 +2118,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"YiVL-6B-Chat": {
|
||||
"Yi-VL-6B-Chat": {
|
||||
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
|
||||
},
|
||||
"YiVL-34B-Chat": {
|
||||
"Yi-VL-34B-Chat": {
|
||||
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
|
||||
},
|
||||
},
|
||||
|
@ -72,4 +72,4 @@ def print_env() -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
||||
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
|
||||
|
@ -20,6 +20,7 @@ import os
|
||||
import sys
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
from .constants import RUNNING_LOG
|
||||
@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
|
||||
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
super().__init__()
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||
self._formatter = logging.Formatter(
|
||||
fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
self.setLevel(logging.INFO)
|
||||
self.setFormatter(formatter)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
||||
if os.path.exists(self.running_log):
|
||||
@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
|
||||
if record.name == "httpx":
|
||||
return
|
||||
|
||||
log_entry = self.format(record)
|
||||
log_entry = self._formatter.format(record)
|
||||
self.thread_pool.submit(self._write_log, log_entry)
|
||||
|
||||
def close(self) -> None:
|
||||
@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
|
||||
return super().close()
|
||||
|
||||
|
||||
class _Logger(logging.Logger):
|
||||
r"""
|
||||
A logger that supports info_rank0 and warning_once.
|
||||
"""
|
||||
|
||||
def info_rank0(self, *args, **kwargs) -> None:
|
||||
self.info(*args, **kwargs)
|
||||
|
||||
def warning_rank0(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
def warning_once(self, *args, **kwargs) -> None:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_default_logging_level() -> "logging._Level":
|
||||
r"""
|
||||
Returns the default logging level.
|
||||
@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
|
||||
if env_level_str.upper() in logging._nameToLevel:
|
||||
return logging._nameToLevel[env_level_str.upper()]
|
||||
else:
|
||||
raise ValueError("Unknown logging level: {}.".format(env_level_str))
|
||||
raise ValueError(f"Unknown logging level: {env_level_str}.")
|
||||
|
||||
return _default_log_level
|
||||
|
||||
@ -84,7 +99,7 @@ def _get_library_name() -> str:
|
||||
return __name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_library_root_logger() -> "logging.Logger":
|
||||
def _get_library_root_logger() -> "_Logger":
|
||||
return logging.getLogger(_get_library_name())
|
||||
|
||||
|
||||
@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
|
||||
global _default_handler
|
||||
|
||||
with _thread_lock:
|
||||
if _default_handler:
|
||||
if _default_handler: # already configured
|
||||
return
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
_default_handler = logging.StreamHandler(sys.stdout)
|
||||
_default_handler.setFormatter(formatter)
|
||||
@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> "logging.Logger":
|
||||
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
r"""
|
||||
Returns a logger with the specified name. It it not supposed to be accessed externally.
|
||||
"""
|
||||
@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
|
||||
|
||||
_configure_library_root_logger()
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def add_handler(handler: "logging.Handler") -> None:
|
||||
r"""
|
||||
Adds a handler to the root logger.
|
||||
"""
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().addHandler(handler)
|
||||
|
||||
|
||||
def remove_handler(handler: logging.Handler) -> None:
|
||||
r"""
|
||||
Removes a handler to the root logger.
|
||||
"""
|
||||
_configure_library_root_logger()
|
||||
_get_library_root_logger().removeHandler(handler)
|
||||
|
||||
|
||||
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
self.info(*args, **kwargs)
|
||||
|
||||
|
||||
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
self.warning(*args, **kwargs)
|
||||
|
||||
|
||||
logging.Logger.info_rank0 = info_rank0
|
||||
logging.Logger.warning_rank0 = warning_rank0
|
||||
logging.Logger.warning_once = warning_once
|
||||
|
@ -32,7 +32,7 @@ from transformers.utils import (
|
||||
)
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from .logging import get_logger
|
||||
from . import logging
|
||||
|
||||
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
@ -48,7 +48,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
@ -76,12 +76,12 @@ def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
else:
|
||||
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
|
||||
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
|
||||
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
|
||||
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||
|
||||
@ -231,18 +231,35 @@ def torch_gc() -> None:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
||||
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
||||
def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
|
||||
return model_args.model_name_or_path
|
||||
|
||||
try:
|
||||
from modelscope import snapshot_download
|
||||
if use_modelscope():
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
return snapshot_download(
|
||||
model_args.model_name_or_path,
|
||||
revision=revision,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
if use_openmind():
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
from openmind.utils.hub import snapshot_download # type: ignore
|
||||
|
||||
return snapshot_download(
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
|
||||
def use_modelscope() -> bool:
|
||||
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
||||
|
||||
|
||||
def use_openmind() -> bool:
|
||||
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
|
||||
|
@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
|
||||
return _get_package_version("transformers") >= version.parse("4.43.0")
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_transformers_version_equal_to_4_46():
|
||||
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _is_package_available("uvicorn")
|
||||
|
||||
|
@ -19,7 +19,7 @@ from typing import Any, Dict, List
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
from .logging import get_logger
|
||||
from . import logging
|
||||
from .packages import is_matplotlib_available
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def smooth(scalars: List[float]) -> List[float]:
|
||||
@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
Plots loss curves and saves the image.
|
||||
"""
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for key in keys:
|
||||
@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
metrics.append(data["log_history"][i][key])
|
||||
|
||||
if len(metrics) == 0:
|
||||
logger.warning(f"No metric {key} to plot.")
|
||||
logger.warning_rank0(f"No metric {key} to plot.")
|
||||
continue
|
||||
|
||||
plt.figure()
|
||||
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
||||
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
||||
plt.title("training {} of {}".format(key, save_dictionary))
|
||||
plt.title(f"training {key} of {save_dictionary}")
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
|
@ -41,6 +41,10 @@ class DataArguments:
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
image_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
||||
@ -111,7 +115,13 @@ class DataArguments:
|
||||
)
|
||||
tokenized_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the tokenized datasets."},
|
||||
metadata={
|
||||
"help": (
|
||||
"Path to save or load the tokenized datasets. "
|
||||
"If tokenized_path not exists, it will save the tokenized datasets. "
|
||||
"If tokenized_path exists, it will load the tokenized datasets."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@ -123,6 +133,9 @@ class DataArguments:
|
||||
self.dataset = split_arg(self.dataset)
|
||||
self.eval_dataset = split_arg(self.eval_dataset)
|
||||
|
||||
if self.image_dir is None:
|
||||
self.image_dir = self.dataset_dir
|
||||
|
||||
if self.dataset is None and self.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -267,6 +267,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||
)
|
||||
om_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Modelers Hub."},
|
||||
)
|
||||
print_param_status: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
||||
@ -308,20 +312,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
|
||||
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
|
||||
arg_dict = old_arg.to_dict()
|
||||
arg_dict.update(**kwargs)
|
||||
for attr in fields(cls):
|
||||
if not attr.init:
|
||||
arg_dict.pop(attr.name)
|
||||
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
|
||||
init_args, lazy_args = {}, {}
|
||||
for attr in fields(source):
|
||||
if attr.init:
|
||||
init_args[attr.name] = getattr(source, attr.name)
|
||||
else:
|
||||
lazy_args[attr.name] = getattr(source, attr.name)
|
||||
|
||||
new_arg = cls(**arg_dict)
|
||||
new_arg.compute_dtype = old_arg.compute_dtype
|
||||
new_arg.device_map = old_arg.device_map
|
||||
new_arg.model_max_length = old_arg.model_max_length
|
||||
new_arg.block_diag_attn = old_arg.block_diag_attn
|
||||
return new_arg
|
||||
init_args.update(kwargs)
|
||||
result = cls(**init_args)
|
||||
for name, value in lazy_args.items():
|
||||
setattr(result, name, value)
|
||||
|
||||
return result
|
||||
|
@ -15,7 +15,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
|
||||
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
check_dependencies()
|
||||
@ -57,7 +56,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
|
||||
|
||||
if unknown_args:
|
||||
print(parser.format_help())
|
||||
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
|
||||
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
return (*parsed_args,)
|
||||
|
||||
|
||||
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
def _set_transformers_logging() -> None:
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
@ -104,7 +103,7 @@ def _verify_model_args(
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
if data_args.template == "yi" and model_args.use_fast_tokenizer:
|
||||
logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||
model_args.use_fast_tokenizer = False
|
||||
|
||||
|
||||
@ -123,7 +122,7 @@ def _check_extra_dependencies(
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0")
|
||||
require_version("vllm>=0.4.3,<=0.6.3", "To fix: pip install vllm>=0.4.3,<=0.6.3")
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if data_args.neat_packing and not data_args.packing:
|
||||
logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
|
||||
logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
|
||||
data_args.packing = True
|
||||
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
and model_args.resize_vocab
|
||||
and finetuning_args.additional_target is None
|
||||
):
|
||||
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
|
||||
logger.warning_rank0(
|
||||
"Remember to add embedding layers to `additional_target` to make the added tokens trainable."
|
||||
)
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||
logger.warning("We recommend enable mixed precision training.")
|
||||
logger.warning_rank0("We recommend enable mixed precision training.")
|
||||
|
||||
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
|
||||
logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
|
||||
logger.warning_rank0(
|
||||
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
|
||||
)
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
|
||||
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
|
||||
logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
|
||||
|
||||
# Post-process training arguments
|
||||
if (
|
||||
@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
|
||||
can_resume_from_checkpoint = False
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
logger.warning("Cannot resume from checkpoint in current stage.")
|
||||
logger.warning_rank0("Cannot resume from checkpoint in current stage.")
|
||||
training_args.resume_from_checkpoint = None
|
||||
else:
|
||||
can_resume_from_checkpoint = True
|
||||
@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args.resume_from_checkpoint = last_checkpoint
|
||||
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
|
||||
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
|
||||
logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
|
||||
logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
|
||||
|
||||
if (
|
||||
finetuning_args.stage in ["rm", "ppo"]
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
and training_args.resume_from_checkpoint is not None
|
||||
):
|
||||
logger.warning(
|
||||
logger.warning_rank0(
|
||||
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
|
@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
|
||||
from .model_utils.quantization import QuantizationMethod
|
||||
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||
@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _setup_full_tuning(
|
||||
@ -45,7 +45,7 @@ def _setup_full_tuning(
|
||||
if not is_trainable:
|
||||
return
|
||||
|
||||
logger.info("Fine-tuning method: Full")
|
||||
logger.info_rank0("Fine-tuning method: Full")
|
||||
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
|
||||
for name, param in model.named_parameters():
|
||||
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
|
||||
@ -64,7 +64,7 @@ def _setup_freeze_tuning(
|
||||
if not is_trainable:
|
||||
return
|
||||
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
logger.info_rank0("Fine-tuning method: Freeze")
|
||||
if hasattr(model.config, "text_config"): # composite models
|
||||
config = getattr(model.config, "text_config")
|
||||
else:
|
||||
@ -133,7 +133,7 @@ def _setup_freeze_tuning(
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
|
||||
logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
|
||||
|
||||
|
||||
def _setup_lora_tuning(
|
||||
@ -145,7 +145,7 @@ def _setup_lora_tuning(
|
||||
cast_trainable_params_to_fp32: bool,
|
||||
) -> "PeftModel":
|
||||
if is_trainable:
|
||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
|
||||
adapter_to_resume = None
|
||||
|
||||
@ -182,7 +182,7 @@ def _setup_lora_tuning(
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(adapter_to_merge) > 0:
|
||||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
if model_args.use_unsloth:
|
||||
@ -190,7 +190,7 @@ def _setup_lora_tuning(
|
||||
else:
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
@ -219,7 +219,7 @@ def _setup_lora_tuning(
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
finetuning_args.additional_target = module_names
|
||||
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
@ -236,11 +236,11 @@ def _setup_lora_tuning(
|
||||
else:
|
||||
if finetuning_args.pissa_init:
|
||||
if finetuning_args.pissa_iter == -1:
|
||||
logger.info("Using PiSSA initialization.")
|
||||
logger.info_rank0("Using PiSSA initialization.")
|
||||
peft_kwargs["init_lora_weights"] = "pissa"
|
||||
else:
|
||||
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
|
||||
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
|
||||
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
|
||||
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
|
||||
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
@ -284,11 +284,11 @@ def init_adapter(
|
||||
if not is_trainable:
|
||||
pass
|
||||
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
||||
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
|
||||
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
|
||||
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
|
||||
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
|
||||
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
|
||||
else:
|
||||
logger.info("Upcasting trainable params to float32.")
|
||||
logger.info_rank0("Upcasting trainable params to float32.")
|
||||
cast_trainable_params_to_fp32 = True
|
||||
|
||||
if finetuning_args.finetuning_type == "full":
|
||||
@ -300,6 +300,6 @@ def init_adapter(
|
||||
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
|
||||
raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
|
||||
|
||||
return model
|
||||
|
@ -18,15 +18,15 @@ import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
|
||||
from ..extras import logging
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
from .model_utils.misc import register_autoclass
|
||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
from .model_utils.visual import get_image_seqlen
|
||||
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
||||
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -35,7 +35,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class TokenizerModule(TypedDict):
|
||||
@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
skip_check_imports()
|
||||
model_args.model_name_or_path = try_download_model_from_ms(model_args)
|
||||
model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
|
||||
return {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
r"""
|
||||
Loads pretrained tokenizer.
|
||||
Loads pretrained tokenizer and optionally loads processor.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
@ -82,33 +82,30 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise OSError("Failed to load tokenizer.") from e
|
||||
|
||||
if model_args.new_special_tokens is not None:
|
||||
num_added_tokens = tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=model_args.new_special_tokens),
|
||||
replace_additional_special_tokens=False,
|
||||
)
|
||||
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
|
||||
logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
|
||||
if num_added_tokens > 0 and not model_args.resize_vocab:
|
||||
model_args.resize_vocab = True
|
||||
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
|
||||
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
|
||||
|
||||
patch_tokenizer(tokenizer)
|
||||
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||
setattr(processor, "video_resolution", model_args.video_resolution)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
setattr(processor, "video_maxlen", model_args.video_maxlen)
|
||||
except Exception:
|
||||
patch_processor(processor, config, tokenizer, model_args)
|
||||
except Exception as e:
|
||||
logger.debug(f"Processor was not found: {e}.")
|
||||
processor = None
|
||||
|
||||
# Avoid load tokenizer, see:
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
|
||||
if "Processor" not in processor.__class__.__name__:
|
||||
if processor is not None and "Processor" not in processor.__class__.__name__:
|
||||
processor = None
|
||||
|
||||
return {"tokenizer": tokenizer, "processor": processor}
|
||||
@ -135,6 +132,7 @@ def load_model(
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = load_config(model_args)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||
|
||||
model = None
|
||||
lazy_load = False
|
||||
@ -157,7 +155,7 @@ def load_model(
|
||||
load_class = AutoModelForCausalLM
|
||||
|
||||
if model_args.train_from_scratch:
|
||||
model = load_class.from_config(config)
|
||||
model = load_class.from_config(config, trust_remote_code=True)
|
||||
else:
|
||||
model = load_class.from_pretrained(**init_kwargs)
|
||||
|
||||
@ -182,7 +180,7 @@ def load_model(
|
||||
vhead_params = load_valuehead_params(vhead_path, model_args)
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
@ -200,9 +198,9 @@ def load_model(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
)
|
||||
else:
|
||||
param_stats = "all params: {:,}".format(all_param)
|
||||
param_stats = f"all params: {all_param:,}"
|
||||
|
||||
logger.info(param_stats)
|
||||
logger.info_rank0(param_stats)
|
||||
|
||||
if model_args.print_param_status:
|
||||
for name, param in model.named_parameters():
|
||||
|
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
|
||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def configure_attn_implementation(
|
||||
@ -37,13 +37,16 @@ def configure_attn_implementation(
|
||||
if is_flash_attn_2_available():
|
||||
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
|
||||
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
|
||||
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
|
||||
model_args.flash_attn = "fa2"
|
||||
if model_args.flash_attn != "fa2":
|
||||
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
|
||||
model_args.flash_attn = "fa2"
|
||||
else:
|
||||
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
|
||||
logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
|
||||
model_args.flash_attn = "disabled"
|
||||
elif model_args.flash_attn == "sdpa":
|
||||
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
|
||||
logger.warning_rank0(
|
||||
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
|
||||
)
|
||||
|
||||
if model_args.flash_attn == "auto":
|
||||
return
|
||||
@ -53,18 +56,18 @@ def configure_attn_implementation(
|
||||
|
||||
elif model_args.flash_attn == "sdpa":
|
||||
if not is_torch_sdpa_available():
|
||||
logger.warning("torch>=2.1.1 is required for SDPA attention.")
|
||||
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
|
||||
return
|
||||
|
||||
requested_attn_implementation = "sdpa"
|
||||
elif model_args.flash_attn == "fa2":
|
||||
if not is_flash_attn_2_available():
|
||||
logger.warning("FlashAttention-2 is not installed.")
|
||||
logger.warning_rank0("FlashAttention-2 is not installed.")
|
||||
return
|
||||
|
||||
requested_attn_implementation = "flash_attention_2"
|
||||
else:
|
||||
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
|
||||
raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
|
||||
|
||||
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
|
||||
setattr(config, "attn_implementation", requested_attn_implementation)
|
||||
@ -79,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
|
||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||
|
||||
if attn_implementation == "flash_attention_2":
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
|
||||
elif attn_implementation == "sdpa":
|
||||
logger.info("Using torch SDPA for faster training and inference.")
|
||||
logger.info_rank0("Using torch SDPA for faster training and inference.")
|
||||
else:
|
||||
logger.info("Using vanilla attention implementation.")
|
||||
logger.info_rank0("Using vanilla attention implementation.")
|
||||
|
@ -19,14 +19,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from functools import partial, wraps
|
||||
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import LAYERNORM_NAMES
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -35,7 +35,7 @@ if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_unsloth_gradient_checkpointing_func() -> Callable:
|
||||
@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
|
||||
Only applies gradient checkpointing to trainable layers.
|
||||
"""
|
||||
|
||||
@wraps(gradient_checkpointing_func)
|
||||
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
|
||||
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: "torch.nn.Module" = func.__self__
|
||||
|
||||
@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
|
||||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
|
||||
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
|
||||
|
||||
return custom_gradient_checkpointing_func
|
||||
|
||||
|
||||
@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
if not self.supports_gradient_checkpointing:
|
||||
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
|
||||
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
||||
|
||||
if gradient_checkpointing_kwargs is None:
|
||||
gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||
@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
|
||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
self.enable_input_require_grads()
|
||||
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
else: # have already enabled input require gradients
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||
|
||||
@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
|
||||
(3) add the upcasting of the lm_head in fp32
|
||||
"""
|
||||
if model_args.upcast_layernorm:
|
||||
logger.info("Upcasting layernorm weights in float32.")
|
||||
logger.info_rank0("Upcasting layernorm weights in float32.")
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if not model_args.disable_gradient_checkpointing:
|
||||
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||
logger.warning("Current model does not support gradient checkpointing.")
|
||||
logger.warning_rank0("Current model does not support gradient checkpointing.")
|
||||
else:
|
||||
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
|
||||
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||
@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
|
||||
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
logger.info_rank0("Gradient checkpointing enabled.")
|
||||
|
||||
if model_args.upcast_lmhead_output:
|
||||
output_layer = model.get_output_embeddings()
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
logger.info_rank0("Upcasting lm_head outputs in float32.")
|
||||
output_layer.register_forward_hook(_fp32_forward_post_hook)
|
||||
|
@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
|
||||
@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
|
||||
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
|
||||
|
@ -12,9 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -23,10 +24,15 @@ if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
def apply_liger_kernel(
|
||||
config: "PretrainedConfig",
|
||||
model_args: "ModelArguments",
|
||||
is_trainable: bool,
|
||||
require_logits: bool,
|
||||
) -> None:
|
||||
if not is_trainable or not model_args.enable_liger_kernel:
|
||||
return
|
||||
|
||||
@ -48,8 +54,14 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
|
||||
elif model_type == "qwen2_vl":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
|
||||
else:
|
||||
logger.warning("Current model does not support liger kernel.")
|
||||
logger.warning_rank0("Current model does not support liger kernel.")
|
||||
return
|
||||
|
||||
apply_liger_kernel()
|
||||
logger.info("Liger kernel has been applied to the model.")
|
||||
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
|
||||
logger.info_rank0("Current training stage does not support chunked cross entropy.")
|
||||
kwargs = {"fused_linear_cross_entropy": False}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
apply_liger_kernel(**kwargs)
|
||||
logger.info_rank0("Liger kernel has been applied to the model.")
|
||||
|
@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
Cache,
|
||||
LlamaAttention,
|
||||
@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.packages import is_transformers_version_greater_than_4_43
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
transformers_logger = logging.get_logger(__name__)
|
||||
transformers_logger = transformers.utils.logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Modified from:
|
||||
@ -86,7 +86,7 @@ def llama_attention_forward(
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: "torch.Tensor") -> "torch.Tensor":
|
||||
@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: "torch.Tensor") -> "torch.Tensor":
|
||||
@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: "torch.Tensor") -> "torch.Tensor":
|
||||
@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
|
||||
def _apply_llama_patch() -> None:
|
||||
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
LlamaAttention.forward = llama_attention_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
||||
@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
|
||||
if not is_trainable or not model_args.shift_attn:
|
||||
return
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
_apply_llama_patch()
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
logger.warning_rank0("Current model does not support shift short attention.")
|
||||
|
@ -14,14 +14,14 @@
|
||||
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
|
||||
@ -34,7 +34,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
|
||||
forbidden_modules.add("output_layer")
|
||||
elif model_type == "internlm2":
|
||||
forbidden_modules.add("output")
|
||||
elif model_type in ["llava", "paligemma"]:
|
||||
elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
|
||||
forbidden_modules.add("multi_modal_projector")
|
||||
elif model_type == "qwen2_vl":
|
||||
forbidden_modules.add("merger")
|
||||
@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
|
||||
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
||||
logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
|
||||
return list(module_names)
|
||||
|
||||
|
||||
@ -67,12 +67,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
|
||||
|
||||
if num_layers % num_layer_trainable != 0:
|
||||
raise ValueError(
|
||||
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
|
||||
f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
|
||||
)
|
||||
|
||||
stride = num_layers // num_layer_trainable
|
||||
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
||||
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
|
||||
trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
|
||||
module_names = []
|
||||
for name, _ in model.named_modules():
|
||||
if any(target_module in name for target_module in target_modules) and any(
|
||||
@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
|
||||
):
|
||||
module_names.append(name)
|
||||
|
||||
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||
logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||
return module_names
|
||||
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user