Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-27 09:10:35 +08:00)

Compare commits: a754604c11...main (6 commits)

Commits: eceec8ab69, b44f651e09, 55590f5ece, a1b1931b4a, 3c17f2722c, a882e2d5fc
.github/copilot-instructions.md (vendored, new file, 180 lines)
@@ -0,0 +1,180 @@
# GitHub Copilot Instructions for LLaMA Factory

## Project Overview

LLaMA Factory is an efficient fine-tuning framework for 100+ large language models (LLMs). It provides:

- Support for various models: LLaMA, LLaVA, Mistral, Qwen, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
- Multiple training methods: pre-training, supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO
- Scalable resources: 16-bit full-tuning, freeze-tuning, LoRA, and QLoRA variants
- Advanced algorithms: GaLore, BAdam, APOLLO, Adam-mini, Muon, OFT, DoRA, etc.
- Web UI (LLaMA Board) and CLI interfaces

### Architecture Versions

LLaMA Factory has two parallel architectures that can be switched via the `USE_V1` environment variable:

**v0 (default)** - File hierarchy:
- `api`, `webui` → `chat`, `eval`, `train` → `data`, `model` → `hparams` → `extras`

**v1** - File hierarchy:
- `trainers` → `core` → `accelerator`, `plugins`, `config` → `utils`

Set `USE_V1=1` to enable the v1 architecture.
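A minimal sketch of applying the switch from the shell (the config path is the LoRA SFT example referenced later in this document; whether a given v0 config runs unchanged under v1 is an assumption to verify):

```bash
# v0 (default architecture)
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml

# v1 architecture, selected via the environment variable for a single run
USE_V1=1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```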
## Code Structure

### v0 Architecture (Default)

- `src/llamafactory/` - Main package directory
  - `api/` - OpenAI-style API implementation
  - `chat/` - Chat interface implementation
  - `cli.py` - Command-line interface
  - `data/` - Data processing and dataset handling
  - `eval/` - Model evaluation utilities
  - `extras/` - Additional utilities and helpers
  - `hparams/` - Hyperparameter definitions
  - `model/` - Model loading, patching, and utilities
  - `train/` - Training pipeline implementation
  - `webui/` - Gradio-based web interface
- `src/train.py` - Training entry script (delegates to `llamafactory.train.tuner`)
- `src/webui.py` - Web UI entry script (delegates to `llamafactory.webui.interface`)
- `src/api.py` - API server entry script (delegates to `llamafactory.api.app`)
- `tests/` - Test suite
- `examples/` - Example configurations for various training scenarios
- `data/` - Dataset definitions and examples

### v1 Architecture (USE_V1=1)

- `src/llamafactory/v1/` - Version 1 package directory
  - `trainers/` - Training implementations
  - `core/` - Core training utilities
  - `accelerator/` - Acceleration and distributed training
  - `plugins/` - Pluggable components (model, data, sampler, trainer)
  - `config/` - Configuration management
  - `utils/` - Utility functions

## Development Practices

### Code Style

- Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
- Use ruff for linting and formatting
- Line length: 119 characters
- Indentation: 4 spaces
- Quote style: double quotes
- Use Google-style docstrings for documentation

### Import Organization

- Known first-party: `llamafactory`
- Known third-party: `accelerate`, `datasets`, `gradio`, `numpy`, `peft`, `torch`, `transformers`, `trl`
- Use 2 blank lines after imports
### Quality Checks

Before committing code, run:

```bash
make style    # Auto-fix style issues
make quality  # Check code quality
make test     # Run test suite
```

Or use the combined command:

```bash
make commit   # Run pre-commit hooks
```

### Testing

- Use pytest for testing
- Tests are located in the `tests/` and `tests_v1/` directories
- Run tests with `make test` (which runs `WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/`)
- Disable wandb during testing to avoid external dependencies
- **Note**: Training configurations require GPU machines, so training is typically not tested end-to-end. Use `make test` to validate file-level functionality (a narrower run is sketched after this list).
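A minimal sketch of narrowing the run while iterating (the `tests/data` path and the `-k` keyword are illustrative; point them at whatever module you are touching):

```bash
# Full suite, exactly what `make test` runs:
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/

# Hypothetical focused run against a single area while iterating:
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/data -k "template"
```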
### Building

Build the package with:

```bash
pip3 install build && python3 -m build
```

### License

- All source files must include the Apache 2.0 license header
- Check license headers with `make license`

## Common Patterns

### Configuration Files

- Training configurations are typically YAML or JSON files in the `examples/` directory
- Hyperparameters are defined using dataclasses in `src/llamafactory/hparams/` (see the sketch after this list)
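A minimal sketch of such a config, written and launched from the shell. The keys follow the style of the `examples/train_lora/` configs, but the model, dataset, and values here are illustrative assumptions; the authoritative key list lives in the `hparams/` dataclasses:

```bash
# Hypothetical minimal LoRA SFT config; adjust model, dataset, and paths to your setup.
cat > /tmp/llama3_lora_sft_demo.yaml <<'EOF'
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
dataset: alpaca_en_demo
template: llama3
cutoff_len: 2048
output_dir: saves/llama3-8b/lora/sft-demo
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
bf16: true
EOF

llamafactory-cli train /tmp/llama3_lora_sft_demo.yaml
```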
### Model Support

- New model support is added through model patches in `src/llamafactory/model/`
- Visual models use the visual utilities in `src/llamafactory/model/model_utils/visual.py`
- Quantization support is in `src/llamafactory/model/model_utils/quantization.py`

### Data Processing

- Dataset definitions are in `data/dataset_info.json`
- Data templates and processors are in `src/llamafactory/data/`

### Training

- Training pipelines are in `src/llamafactory/train/`
- Support for different training methods: SFT, DPO, PPO, RM, PT, KTO, ORPO

## Key Dependencies

- Python >= 3.9.0
- PyTorch and transformers for model handling
- datasets for data processing
- peft for parameter-efficient fine-tuning
- accelerate for distributed training
- gradio for web UI
- trl for reinforcement learning
- Optional: vllm/sglang for inference, flash-attention-2, unsloth, liger-kernel
## Entry Points

- **CLI Training**: `llamafactory-cli train --config examples/train_lora/llama3_lora_sft.yaml`
- **Web UI**: `llamafactory-cli webui` or `python src/webui.py`
- **API Server**: `llamafactory-cli api` or `python src/api.py`
- **Chat Interface**: `llamafactory-cli chat --model_name_or_path MODEL_PATH` (see the combined sketch after this list)
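A sketch of chaining these entry points into a fine-tune-then-chat loop (the adapter path and `--template` value are assumptions; they must match the `output_dir` and template used during training):

```bash
# Train a LoRA adapter, then chat with the base model plus that adapter.
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
llamafactory-cli chat \
    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
    --adapter_name_or_path saves/llama3-8b/lora/sft \
    --template llama3
```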
## Environment Setup

For development:

```bash
pip install -e ".[dev]"
```

## Important Notes

- The project supports multiple backends: default PyTorch, vLLM, SGLang
- Megatron-core training is supported via mcore_adapter
- SwanLab and W&B are supported for experiment tracking
- Docker support is available with pre-built images
- Day-0/Day-1 support for the latest cutting-edge models
- Multi-modal support for vision and audio understanding tasks

## Contribution Guidelines

1. Fork the repository
2. Create a development branch
3. Set up the development environment with `pip install -e ".[dev]"`
4. Make changes following the style guide
5. Run quality checks: `make style && make quality`
6. Run tests: `make test`
7. Submit a pull request

## Common Commands

- `make style` - Format code
- `make quality` - Run linters
- `make test` - Run tests
- `make commit` - Install and run pre-commit hooks
- `make license` - Check license headers
.github/workflows/docker.yml (vendored, 26 changed lines)
@@ -7,7 +7,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "requirements.txt"
|
||||
- "pyproject.toml"
|
||||
- "docker/**"
|
||||
- ".github/workflows/*.yml"
|
||||
pull_request:
|
||||
@@ -15,7 +15,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "requirements.txt"
|
||||
- "pyproject.toml"
|
||||
- "docker/**"
|
||||
- ".github/workflows/*.yml"
|
||||
release:
|
||||
@@ -29,16 +29,13 @@ jobs:
|
||||
matrix:
|
||||
include:
|
||||
- device: "cuda"
|
||||
npu_type: ""
|
||||
- device: "npu"
|
||||
npu_type: "a2"
|
||||
- device: "npu"
|
||||
npu_type: "a3"
|
||||
- device: "npu-a2"
|
||||
- device: "npu-a3"
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}-${{ matrix.npu_type }}
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
environment:
|
||||
@@ -55,16 +52,11 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Get llamafactory version
|
||||
id: version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "release" ]; then
|
||||
echo "tag=$(python setup.py --version)" >> "$GITHUB_OUTPUT"
|
||||
echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "tag=latest" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
@@ -93,14 +85,12 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/docker-cuda/Dockerfile
|
||||
build-args: |
|
||||
EXTRAS=metrics,deepspeed,liger-kernel
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
|
||||
|
||||
- name: Build and push Docker image (NPU-A2)
|
||||
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a2' }}
|
||||
if: ${{ matrix.device == 'npu-a2' }}
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
@@ -112,7 +102,7 @@ jobs:
|
||||
quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
|
||||
|
||||
- name: Build and push Docker image (NPU-A3)
|
||||
if: ${{ matrix.device == 'npu' && matrix.npu_type == 'a3' }}
|
||||
if: ${{ matrix.device == 'npu-a3' }}
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
|
||||
.github/workflows/publish.yml (vendored, 7 changed lines)
@@ -23,10 +23,11 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
         with:
-          python-version: "3.9"
+          python-version: "3.11"
+          github-token: ${{ github.token }}

       - name: Build package
         run: |
.github/workflows/tests.yml (vendored, 37 changed lines)
@@ -7,7 +7,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "requirements.txt"
|
||||
- "pyproject.toml"
|
||||
- "Makefile"
|
||||
- ".github/workflows/*.yml"
|
||||
pull_request:
|
||||
@@ -15,7 +15,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "requirements.txt"
|
||||
- "pyproject.toml"
|
||||
- "Makefile"
|
||||
- ".github/workflows/*.yml"
|
||||
|
||||
@@ -25,10 +25,9 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python:
|
||||
- "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
- "3.12"
|
||||
# - "3.13" # enable after trl is upgraded
|
||||
os:
|
||||
- "ubuntu-latest"
|
||||
- "windows-latest"
|
||||
@@ -36,18 +35,15 @@ jobs:
|
||||
transformers:
|
||||
- null
|
||||
include: # test backward compatibility
|
||||
- python: "3.9"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.49.0"
|
||||
- python: "3.9"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.51.0"
|
||||
- python: "3.9"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.53.0"
|
||||
exclude: # exclude python 3.9 on macos
|
||||
- python: "3.9"
|
||||
os: "macos-latest"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
@@ -63,21 +59,23 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
github-token: ${{ github.token }}
|
||||
enable-cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
python -m pip install ".[dev]"
|
||||
uv venv
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Install transformers
|
||||
if: ${{ matrix.transformers }}
|
||||
run: |
|
||||
python -m pip install "transformers==${{ matrix.transformers }}"
|
||||
uv pip install "transformers==${{ matrix.transformers }}"
|
||||
|
||||
- name: Cache files
|
||||
id: hf-hub-cache
|
||||
@@ -89,18 +87,25 @@ jobs:
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: ${{ runner.temp }}/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
|
||||
.github/workflows/tests_npu.yml (vendored, 20 changed lines)
@@ -7,7 +7,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "requirements.txt"
|
||||
- "pyproject.toml"
|
||||
- "Makefile"
|
||||
- ".github/workflows/*.yml"
|
||||
pull_request:
|
||||
@@ -15,7 +15,7 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "requirements.txt"
|
||||
- "pyproject.toml"
|
||||
- "Makefile"
|
||||
- ".github/workflows/*.yml"
|
||||
|
||||
@@ -48,10 +48,15 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[torch-npu,dev]" torch-npu==${{matrix.pytorch_npu}}
|
||||
uv venv
|
||||
uv pip install torch-npu==${{matrix.pytorch_npu}}
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
- name: Install node
|
||||
run: |
|
||||
@@ -70,18 +75,25 @@ jobs:
|
||||
- name: Check quality
|
||||
run: |
|
||||
make style && make quality
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check license
|
||||
run: |
|
||||
make license
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Check build
|
||||
run: |
|
||||
make build
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
env:
|
||||
UV_NO_SYNC: 1
|
||||
HF_HOME: /root/.cache/huggingface
|
||||
HF_HUB_OFFLINE: "${{ steps.hf-hub-cache.outputs.cache-hit == 'true' && '1' || '0' }}"
|
||||
|
||||
MANIFEST.in (1 changed line)
@@ -1 +1 @@
-include LICENSE requirements.txt
+include LICENSE
Makefile (24 changed lines)
@@ -1,24 +1,28 @@
 .PHONY: build commit license quality style test

-check_dirs := scripts src tests tests_v1 setup.py
+check_dirs := scripts src tests tests_v1
+
+RUN := $(shell command -v uv >/dev/null 2>&1 && echo "uv run" || echo "")
+BUILD := $(shell command -v uv >/dev/null 2>&1 && echo "uv build" || echo "python -m build")
+TOOL := $(shell command -v uv >/dev/null 2>&1 && echo "uvx" || echo "")

 build:
-    pip3 install build && python3 -m build
+    $(BUILD)

 commit:
-    pre-commit install
-    pre-commit run --all-files
+    $(TOOL) pre-commit install
+    $(TOOL) pre-commit run --all-files

 license:
-    python3 tests/check_license.py $(check_dirs)
+    $(RUN) python3 tests/check_license.py $(check_dirs)

 quality:
-    ruff check $(check_dirs)
-    ruff format --check $(check_dirs)
+    $(TOOL) ruff check $(check_dirs)
+    $(TOOL) ruff format --check $(check_dirs)

 style:
-    ruff check $(check_dirs) --fix
-    ruff format $(check_dirs)
+    $(TOOL) ruff check $(check_dirs) --fix
+    $(TOOL) ruff format $(check_dirs)

 test:
-    WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ tests_v1/
+    WANDB_DISABLED=true $(RUN) pytest -vv --import-mode=importlib tests/ tests_v1/
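For context, a quick way to see which toolchain the updated Makefile will pick on a given machine (a sketch; the `command -v uv` probe is exactly what the Makefile uses, the echoed values are for illustration):

```bash
# Mirrors the Makefile's probe: prefer uv when it is on PATH, otherwise fall back to plain python tooling.
if command -v uv >/dev/null 2>&1; then
    echo 'RUN="uv run"  BUILD="uv build"  TOOL="uvx"'
else
    echo 'RUN=""  BUILD="python -m build"  TOOL=""'
fi
```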
README.md (23 changed lines)
@@ -514,10 +514,12 @@ huggingface-cli login
|
||||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e ".[torch,metrics]" --no-build-isolation
|
||||
pip install -e ".[metrics]" --no-build-isolation
|
||||
```
|
||||
|
||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
|
||||
Optional dependencies available: `metrics`, `deepspeed`. Install with: `pip install -e ".[metrics,deepspeed]"`
|
||||
|
||||
Additional dependencies for specific features are available in `examples/requirements/`.
|
||||
|
||||
#### Install from Docker Image
|
||||
|
||||
@@ -536,13 +538,7 @@ Please refer to [build docker](#build-docker) to build the image yourself.
|
||||
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
|
||||
|
||||
```bash
|
||||
uv sync --extra torch --extra metrics --prerelease=allow
|
||||
```
|
||||
|
||||
Run LLaMA-Factory in the isolated environment:
|
||||
|
||||
```bash
|
||||
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
uv run llamafactory-cli webui
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -579,7 +575,7 @@ To enable FlashAttention-2 on the Windows platform, please use the script from [
|
||||
|
||||
<details><summary>For Ascend NPU users</summary>
|
||||
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher: `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
|
||||
|
||||
```bash
|
||||
# replace the url according to your CANN version and devices
|
||||
@@ -598,8 +594,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
| Requirement | Minimum | Recommend |
|
||||
| ------------ | ------- | -------------- |
|
||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||
| torch | 2.1.0 | 2.4.0 |
|
||||
| torch-npu | 2.1.0 | 2.4.0.post2 |
|
||||
| torch | 2.1.0 | 2.7.1 |
|
||||
| torch-npu | 2.1.0 | 2.7.1 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
| vllm-ascend | - | 0.7.3 |
|
||||
|
||||
@@ -714,7 +710,6 @@ For CUDA users:
|
||||
```bash
|
||||
docker build -f ./docker/docker-cuda/Dockerfile \
|
||||
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||
--build-arg EXTRAS=metrics \
|
||||
-t llamafactory:latest .
|
||||
|
||||
docker run -dit --ipc=host --gpus=all \
|
||||
@@ -731,7 +726,6 @@ For Ascend NPU users:
|
||||
```bash
|
||||
docker build -f ./docker/docker-npu/Dockerfile \
|
||||
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||
--build-arg EXTRAS=torch-npu,metrics \
|
||||
-t llamafactory:latest .
|
||||
|
||||
docker run -dit --ipc=host \
|
||||
@@ -756,7 +750,6 @@ For AMD ROCm users:
|
||||
```bash
|
||||
docker build -f ./docker/docker-rocm/Dockerfile \
|
||||
--build-arg PIP_INDEX=https://pypi.org/simple \
|
||||
--build-arg EXTRAS=metrics \
|
||||
-t llamafactory:latest .
|
||||
|
||||
docker run -dit --ipc=host \
|
||||
|
||||
README_zh.md (20 changed lines)
@@ -516,10 +516,12 @@ huggingface-cli login
|
||||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
|
||||
cd LLaMA-Factory
|
||||
pip install -e ".[torch,metrics]" --no-build-isolation
|
||||
pip install -e ".[metrics]" --no-build-isolation
|
||||
```
|
||||
|
||||
Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, openmind, swanlab, dev
|
||||
Optional extra dependencies: `metrics`, `deepspeed`. Install with `pip install -e ".[metrics,deepspeed]"`.
|
||||
|
||||
For other optional dependencies, see the files in the `examples/requirements/` directory.
|
||||
|
||||
#### Install from Docker Image
|
||||
|
||||
@@ -538,13 +540,7 @@ docker run -it --rm --gpus=all --ipc=host hiyouga/llamafactory:latest
|
||||
Create an isolated Python environment with [uv](https://github.com/astral-sh/uv):
|
||||
|
||||
```bash
|
||||
uv sync --extra torch --extra metrics --prerelease=allow
|
||||
```
|
||||
|
||||
Run LLaMA-Factory in the isolated environment:
|
||||
|
||||
```bash
|
||||
uv run --prerelease=allow llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
|
||||
uv run llamafactory-cli webui
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -581,7 +577,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
|
||||
|
||||
<details><summary>For Ascend NPU users</summary>
|
||||
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and specify extra dependencies with `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:
|
||||
To install LLaMA Factory on Ascend NPU devices, please upgrade Python to version 3.10 or higher and install with `pip install -e . torch-npu==2.7.1`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html) or use the following commands:
|
||||
|
||||
```bash
|
||||
# replace the URL according to your CANN version and devices
|
||||
@@ -600,8 +596,8 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
| Requirement | Minimum | Recommended |
|
||||
| ------------ | ------- | -------------- |
|
||||
| CANN | 8.0.RC1 | 8.0.0.alpha002 |
|
||||
| torch | 2.1.0 | 2.4.0 |
|
||||
| torch-npu | 2.1.0 | 2.4.0.post2 |
|
||||
| torch | 2.1.0 | 2.7.1 |
|
||||
| torch-npu | 2.1.0 | 2.7.1 |
|
||||
| deepspeed | 0.13.2 | 0.13.2 |
|
||||
| vllm-ascend | - | 0.7.3 |
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
ARG EXTRAS=metrics
|
||||
ARG INSTALL_FLASHATTN=false
|
||||
ARG HTTP_PROXY=""
|
||||
|
||||
@@ -27,17 +26,13 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
|
||||
RUN pip install --no-cache-dir --no-build-isolation -e ".[metrics,deepspeed]"
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
@@ -8,7 +8,7 @@ ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
|
||||
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
|
||||
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
|
||||
|
||||
RUN pip install --upgrade pip setuptools wheel --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
|
||||
RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
|
||||
|
||||
RUN pip uninstall -y torch torchvision torch-tensorrt \
|
||||
flash_attn transformer-engine \
|
||||
@@ -56,14 +56,14 @@ ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
|
||||
# pip install LLaMA-Factory
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt /app/
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
||||
|
||||
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
||||
|
||||
COPY . /app/
|
||||
RUN pip install -e ".[metrics]" --no-build-isolation
|
||||
|
||||
# Expose port 7860 for LLaMA Board
|
||||
ENV GRADIO_SERVER_PORT=7860
|
||||
EXPOSE 7860
|
||||
|
||||
@@ -5,7 +5,6 @@ services:
|
||||
context: ../..
|
||||
args:
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: metrics
|
||||
container_name: llamafactory
|
||||
ports:
|
||||
- "7860:7860"
|
||||
|
||||
@@ -5,7 +5,6 @@ FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
ARG EXTRAS=torch-npu,metrics
|
||||
ARG HTTP_PROXY=""
|
||||
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/cpu
|
||||
|
||||
@@ -28,21 +27,15 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
|
||||
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" --index-url "${PYTORCH_INDEX}"
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir -e ".[metrics]" --no-build-isolation
|
||||
|
||||
# Set up volumes
|
||||
# VOLUME [ "/root/.cache/huggingface", "/app/shared_data", "/app/output" ]
|
||||
|
||||
@@ -5,7 +5,6 @@ services:
|
||||
context: ../..
|
||||
args:
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: torch-npu,metrics
|
||||
container_name: llamafactory-a2
|
||||
image: llamafactory:npu-a2
|
||||
volumes:
|
||||
@@ -36,7 +35,6 @@ services:
|
||||
args:
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: torch-npu,metrics
|
||||
container_name: llamafactory-a3
|
||||
image: llamafactory:npu-a3
|
||||
volumes:
|
||||
|
||||
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
ARG EXTRAS=metrics
|
||||
ARG INSTALL_FLASHATTN=false
|
||||
ARG HTTP_PROXY=""
|
||||
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
|
||||
@@ -28,21 +27,14 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools editables "hatchling>=1.18.0"
|
||||
|
||||
# Reinstall pytorch rocm
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install LLaMA Factory
|
||||
RUN pip install --no-cache-dir -e ".[${EXTRAS}]" --no-build-isolation
|
||||
# Reinstall pytorch rocm and install LLaMA Factory
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir --no-build-isolation -e --pre ".[metrics,deepspeed]" --index-url "${PYTORCH_INDEX}"
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
@@ -5,7 +5,6 @@ services:
|
||||
context: ../..
|
||||
args:
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
EXTRAS: metrics
|
||||
container_name: llamafactory
|
||||
ports:
|
||||
- "7860:7860"
|
||||
|
||||
New files under examples/requirements/, one requirement list per optional feature:

- adam-mini.txt: adam-mini
- apollo.txt: apollo-torch
- aqlm.txt: aqlm[gpu]>=1.1.0
- badam.txt: badam>=1.2.1
- bitsandbytes.txt: bitsandbytes>=0.39.0
- eetq.txt: eetq
- fp8-te.txt: transformer_engine[pytorch]>=2.0.0, accelerate>=1.10.0
- fp8.txt: torchao>=0.8.0, accelerate>=1.10.0
- galore.txt: galore-torch
- gptq.txt: optimum>=1.24.0, gptqmodel>=2.0.0
- hqq.txt: hqq
- liger-kernel.txt: liger-kernel>=0.5.5
- minicpm-v.txt: soundfile, torchvision, torchaudio, vector_quantize_pytorch, vocos, msgpack, referencing, jsonschema_specifications
- openmind.txt: openmind
- sglang.txt: sglang[srt]>=0.4.5, transformers==4.51.1
- swanlab.txt: swanlab
- vllm.txt: vllm>=0.4.3,<=0.11.0
pyproject.toml (155 changed lines)
@@ -1,42 +1,123 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "llamafactory"
|
||||
requires-python = ">=3.9.0"
|
||||
dynamic = [
|
||||
"version",
|
||||
"dependencies",
|
||||
"optional-dependencies",
|
||||
"scripts",
|
||||
"authors",
|
||||
"description",
|
||||
"readme",
|
||||
"license",
|
||||
"keywords",
|
||||
"classifiers"
|
||||
dynamic = ["version"]
|
||||
description = "Unified Efficient Fine-Tuning of 100+ LLMs"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
requires-python = ">=3.11.0"
|
||||
authors = [
|
||||
{ name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
|
||||
]
|
||||
keywords = [
|
||||
"AI",
|
||||
"LLM",
|
||||
"GPT",
|
||||
"ChatGPT",
|
||||
"Llama",
|
||||
"Transformer",
|
||||
"DeepSeek",
|
||||
"Pytorch"
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence"
|
||||
]
|
||||
dependencies = [
|
||||
# core deps
|
||||
"torch>=2.4.0",
|
||||
"torchvision>=0.19.0",
|
||||
"torchaudio>=2.4.0",
|
||||
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
|
||||
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
|
||||
"datasets>=2.16.0,<=4.0.0",
|
||||
"accelerate>=1.3.0,<=1.11.0",
|
||||
"peft>=0.14.0,<=0.17.1",
|
||||
"trl>=0.8.6,<=0.9.6",
|
||||
"torchdata>=0.10.0,<=0.11.0",
|
||||
# gui
|
||||
"gradio>=4.38.0,<=6.2.0",
|
||||
"matplotlib>=3.7.0",
|
||||
"tyro<0.9.0",
|
||||
# ops
|
||||
"einops",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"scipy",
|
||||
# model and tokenizer
|
||||
"sentencepiece",
|
||||
"tiktoken",
|
||||
"modelscope",
|
||||
"hf-transfer",
|
||||
"safetensors",
|
||||
# python
|
||||
"av",
|
||||
"fire",
|
||||
"omegaconf",
|
||||
"packaging",
|
||||
"protobuf",
|
||||
"pyyaml",
|
||||
"pydantic",
|
||||
# api
|
||||
"uvicorn",
|
||||
"fastapi",
|
||||
"sse-starlette"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pre-commit", "ruff", "pytest", "build"]
|
||||
metrics = ["nltk", "jieba", "rouge-chinese"]
|
||||
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
|
||||
|
||||
[project.scripts]
|
||||
llamafactory-cli = "llamafactory.cli:main"
|
||||
lmf = "llamafactory.cli:main"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/hiyouga/LLaMA-Factory"
|
||||
Repository = "https://github.com/hiyouga/LLaMA-Factory"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/llamafactory"]
|
||||
|
||||
[tool.hatch.version]
|
||||
path = "src/llamafactory/extras/env.py"
|
||||
pattern = "VERSION = \"(?P<version>[^\"]+)\""
|
||||
|
||||
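A quick sanity check that the version pattern above matches the source file (a sketch; the grep flavor mirrors the one the release workflow in this diff switched to):

```bash
# Both commands should print the same version string that hatchling embeds in the build.
grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py
python3 -c "import re, pathlib; print(re.search(r'VERSION = \"([^\"]+)\"', pathlib.Path('src/llamafactory/extras/env.py').read_text()).group(1))"
```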
[tool.ruff]
|
||||
target-version = "py39"
|
||||
target-version = "py311"
|
||||
line-length = 119
|
||||
indent-width = 4
|
||||
|
||||
[tool.ruff.lint]
|
||||
ignore = [
|
||||
"C408", # collection
|
||||
"C901", # complex
|
||||
"E501", # line too long
|
||||
"E731", # lambda function
|
||||
"E741", # ambiguous var name
|
||||
"D100", # no doc public module
|
||||
"D101", # no doc public class
|
||||
"D102", # no doc public method
|
||||
"D103", # no doc public function
|
||||
"D104", # no doc public package
|
||||
"D105", # no doc magic method
|
||||
"D107", # no doc __init__
|
||||
"C408", # collection
|
||||
"C901", # complex
|
||||
"E501", # line too long
|
||||
"E731", # lambda function
|
||||
"E741", # ambiguous var name
|
||||
"UP007", # no upgrade union
|
||||
"UP045", # no upgrade optional
|
||||
"D100", # no doc public module
|
||||
"D101", # no doc public class
|
||||
"D102", # no doc public method
|
||||
"D103", # no doc public function
|
||||
"D104", # no doc public package
|
||||
"D105", # no doc magic method
|
||||
"D107", # no doc __init__
|
||||
]
|
||||
extend-select = [
|
||||
"C", # complexity
|
||||
@@ -73,23 +154,3 @@ indent-style = "space"
|
||||
docstring-code-format = true
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "torch-npu" },
|
||||
{ extra = "aqlm" },
|
||||
],
|
||||
[
|
||||
{ extra = "torch-npu" },
|
||||
{ extra = "vllm" },
|
||||
],
|
||||
[
|
||||
{ extra = "torch-npu" },
|
||||
{ extra = "sglang" },
|
||||
],
|
||||
[
|
||||
{ extra = "vllm" },
|
||||
{ extra = "sglang" },
|
||||
],
|
||||
]
|
||||
|
||||
requirements.txt (deleted, 39 lines)
@@ -1,39 +0,0 @@
|
||||
# core deps
|
||||
transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'
|
||||
transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'
|
||||
datasets>=2.16.0,<=4.0.0
|
||||
accelerate>=1.3.0,<=1.11.0
|
||||
peft>=0.14.0,<=0.17.1
|
||||
trl>=0.8.6,<=0.9.6
|
||||
torchdata
|
||||
# gui
|
||||
gradio>=4.38.0,<=5.45.0
|
||||
matplotlib>=3.7.0
|
||||
tyro<0.9.0
|
||||
# ops
|
||||
einops
|
||||
numpy<2.0.0
|
||||
pandas>=2.0.0
|
||||
scipy
|
||||
# model and tokenizer
|
||||
sentencepiece
|
||||
tiktoken
|
||||
modelscope>=1.14.0
|
||||
hf-transfer
|
||||
safetensors<=0.5.3
|
||||
# python
|
||||
fire
|
||||
omegaconf
|
||||
packaging
|
||||
protobuf
|
||||
pyyaml
|
||||
pydantic<=2.10.6
|
||||
# api
|
||||
uvicorn
|
||||
fastapi
|
||||
sse-starlette
|
||||
# media
|
||||
av
|
||||
librosa
|
||||
# yanked
|
||||
propcache!=0.4.0
|
||||
@@ -16,7 +16,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@@ -34,7 +33,7 @@ def convert_mca_to_hf(
|
||||
output_path: str = "./output",
|
||||
bf16: bool = False,
|
||||
fp16: bool = False,
|
||||
convert_model_max_length: Optional[int] = None,
|
||||
convert_model_max_length: int | None = None,
|
||||
):
|
||||
"""Convert megatron checkpoint to HuggingFace format.
|
||||
|
||||
@@ -67,11 +66,11 @@ def convert(
|
||||
output_path: str = "./output",
|
||||
bf16: bool = False,
|
||||
fp16: bool = False,
|
||||
convert_model_max_length: Optional[int] = None,
|
||||
convert_model_max_length: int | None = None,
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
expert_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: Optional[int] = None,
|
||||
virtual_pipeline_model_parallel_size: int | None = None,
|
||||
):
|
||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@@ -61,7 +61,7 @@ def calculate_ppl(
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 2048,
|
||||
max_samples: Optional[int] = None,
|
||||
max_samples: int | None = None,
|
||||
train_on_prompt: bool = False,
|
||||
):
|
||||
r"""Calculate the ppl on the dataset of the pre-trained models.
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import gc
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import fire
|
||||
@@ -49,7 +48,7 @@ def vllm_infer(
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 2048,
|
||||
max_samples: Optional[int] = None,
|
||||
max_samples: int | None = None,
|
||||
vllm_config: str = "{}",
|
||||
save_name: str = "generated_predictions.jsonl",
|
||||
temperature: float = 0.95,
|
||||
@@ -58,9 +57,9 @@ def vllm_infer(
|
||||
max_new_tokens: int = 1024,
|
||||
repetition_penalty: float = 1.0,
|
||||
skip_special_tokens: bool = True,
|
||||
default_system: Optional[str] = None,
|
||||
default_system: str | None = None,
|
||||
enable_thinking: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
seed: int | None = None,
|
||||
pipeline_parallel_size: int = 1,
|
||||
image_max_pixels: int = 768 * 768,
|
||||
image_min_pixels: int = 32 * 32,
|
||||
|
||||
setup.py (deleted, 116 lines)
@@ -1,116 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def get_version() -> str:
|
||||
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
|
||||
(version,) = re.findall(pattern, file_content)
|
||||
return version
|
||||
|
||||
|
||||
def get_requires() -> list[str]:
|
||||
with open("requirements.txt", encoding="utf-8") as f:
|
||||
file_content = f.read()
|
||||
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
||||
return lines
|
||||
|
||||
|
||||
def get_console_scripts() -> list[str]:
|
||||
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
|
||||
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
|
||||
console_scripts.append("lmf = llamafactory.cli:main")
|
||||
|
||||
return console_scripts
|
||||
|
||||
|
||||
extra_require = {
|
||||
"torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
|
||||
"torch-npu": ["torch==2.7.1", "torch-npu==2.7.1", "torchvision==0.22.1", "decorator"],
|
||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||
"deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
|
||||
"liger-kernel": ["liger-kernel>=0.5.5"],
|
||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||
"hqq": ["hqq"],
|
||||
"eetq": ["eetq"],
|
||||
"gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
|
||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||
"vllm": ["vllm>=0.4.3,<=0.11.0"],
|
||||
"sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
|
||||
"galore": ["galore-torch"],
|
||||
"apollo": ["apollo-torch"],
|
||||
"badam": ["badam>=1.2.1"],
|
||||
"adam-mini": ["adam-mini"],
|
||||
"minicpm_v": [
|
||||
"soundfile",
|
||||
"torchvision",
|
||||
"torchaudio",
|
||||
"vector_quantize_pytorch",
|
||||
"vocos",
|
||||
"msgpack",
|
||||
"referencing",
|
||||
"jsonschema_specifications",
|
||||
],
|
||||
"openmind": ["openmind"],
|
||||
"swanlab": ["swanlab"],
|
||||
"fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
|
||||
"fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
|
||||
"fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
|
||||
"dev": ["pre-commit", "ruff", "pytest", "build"],
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
setup(
|
||||
name="llamafactory",
|
||||
version=get_version(),
|
||||
author="hiyouga",
|
||||
author_email="hiyouga@buaa.edu.cn",
|
||||
description="Unified Efficient Fine-Tuning of 100+ LLMs",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
|
||||
license="Apache 2.0 License",
|
||||
url="https://github.com/hiyouga/LLaMA-Factory",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
python_requires=">=3.9.0",
|
||||
install_requires=get_requires(),
|
||||
extras_require=extra_require,
|
||||
entry_points={"console_scripts": get_console_scripts()},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -16,7 +16,7 @@ import asyncio
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..extras.constants import EngineName
|
||||
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
api_key = os.getenv("API_KEY")
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
|
||||
if api_key and (auth is None or auth.credentials != api_key):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
|
||||
|
||||
|
||||
@@ -14,10 +14,9 @@
|
||||
|
||||
import time
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@unique
|
||||
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):
|
||||
|
||||
class FunctionAvailable(BaseModel):
|
||||
type: Literal["function", "code_interpreter"] = "function"
|
||||
function: Optional[FunctionDefinition] = None
|
||||
function: FunctionDefinition | None = None
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
@@ -77,35 +76,35 @@ class URL(BaseModel):
|
||||
|
||||
class MultimodalInputItem(BaseModel):
|
||||
type: Literal["text", "image_url", "video_url", "audio_url"]
|
||||
text: Optional[str] = None
|
||||
image_url: Optional[URL] = None
|
||||
video_url: Optional[URL] = None
|
||||
audio_url: Optional[URL] = None
|
||||
text: str | None = None
|
||||
image_url: URL | None = None
|
||||
video_url: URL | None = None
|
||||
audio_url: URL | None = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Role
|
||||
content: Optional[Union[str, list[MultimodalInputItem]]] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
content: str | list[MultimodalInputItem] | None = None
|
||||
tool_calls: list[FunctionCall] | None = None
|
||||
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: Optional[Role] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[list[FunctionCall]] = None
|
||||
role: Role | None = None
|
||||
content: str | None = None
|
||||
tool_calls: list[FunctionCall] | None = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
tools: Optional[list[FunctionAvailable]] = None
|
||||
do_sample: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
tools: list[FunctionAvailable] | None = None
|
||||
do_sample: bool | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
n: int = 1
|
||||
presence_penalty: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
|
||||
class ChatCompletionStreamResponseChoice(BaseModel):
|
||||
index: int
|
||||
delta: ChatCompletionMessage
|
||||
finish_reason: Optional[Finish] = None
|
||||
finish_reason: Finish | None = None
|
||||
|
||||
|
||||
class ChatCompletionResponseUsage(BaseModel):
|
||||
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
|
||||
class ScoreEvaluationRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[str]
|
||||
max_length: Optional[int] = None
|
||||
max_length: int | None = None
|
||||
|
||||
|
||||
class ScoreEvaluationResponse(BaseModel):
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
@@ -40,7 +40,7 @@ class DatasetConverter:
|
||||
dataset_attr: "DatasetAttr"
|
||||
data_args: "DataArguments"
|
||||
|
||||
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
|
||||
def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
|
||||
r"""Optionally concatenate media path to media dir when loading from local disk."""
|
||||
if medias is None:
|
||||
return None
|
||||
|
||||
@@ -16,7 +16,6 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[str] = None
|
||||
tool_format: str | None = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
r"""Forms a list of slots according to the inputs to encode."""
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
def extract(self, content: str) -> str | list["FunctionCall"]:
|
||||
r"""Extract a list of tuples from the response message if using tools.
|
||||
|
||||
Each tuple consists of function name and function arguments.
|
||||
@@ -156,5 +155,5 @@ class ToolFormatter(Formatter):
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
|
||||
def extract(self, content: str) -> str | list["FunctionCall"]:
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
||||
@@ -162,13 +162,13 @@ def _load_single_dataset(
|
||||
|
||||
|
||||
def _get_merged_dataset(
|
||||
dataset_names: Optional[list[str]],
|
||||
dataset_names: list[str] | None,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
return_dict: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
|
||||
) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
|
||||
r"""Return the merged datasets in the standard format."""
|
||||
if dataset_names is None:
|
||||
return None
|
||||
@@ -227,7 +227,7 @@ def _get_dataset_processor(
|
||||
|
||||
|
||||
def _get_preprocessed_dataset(
|
||||
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||
dataset: Union["Dataset", "IterableDataset"] | None,
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
@@ -235,7 +235,7 @@ def _get_preprocessed_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
is_eval: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
) -> Union["Dataset", "IterableDataset"] | None:
|
||||
r"""Preprocesses the dataset, including format checking and tokenization."""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
@@ -22,28 +22,20 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
||||
from transformers.models.mllama.processing_mllama import (
|
||||
convert_sparse_cross_attention_mask_to_dense,
|
||||
get_cross_attention_token_mask,
|
||||
)
|
||||
from typing_extensions import NotRequired, override
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import (
|
||||
is_librosa_available,
|
||||
is_pillow_available,
|
||||
is_pyav_available,
|
||||
is_transformers_version_greater_than,
|
||||
)
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
@@ -71,8 +63,8 @@ if TYPE_CHECKING:
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
|
||||
class EncodedImage(TypedDict):
|
||||
path: Optional[str]
|
||||
bytes: Optional[bytes]
|
||||
path: str | None
|
||||
bytes: bytes | None
|
||||
|
||||
ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
|
||||
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
||||
@@ -152,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:
|
||||
|
||||
@dataclass
|
||||
class MMPluginMixin:
|
||||
image_token: Optional[str]
|
||||
video_token: Optional[str]
|
||||
audio_token: Optional[str]
|
||||
image_token: str | None
|
||||
video_token: str | None
|
||||
audio_token: str | None
|
||||
expand_mm_tokens: bool = True
|
||||
|
||||
def _validate_input(
|
||||
@@ -316,7 +308,14 @@ class MMPluginMixin:
|
||||
results, sampling_rates = [], []
|
||||
for audio in audios:
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio, sampling_rate = librosa.load(audio, sr=sampling_rate)
|
||||
audio, sr = torchaudio.load(audio)
|
||||
if audio.shape[0] > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != sampling_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, sampling_rate)
|
||||
|
||||
audio = audio.squeeze(0).numpy()
|
||||
|
||||
results.append(audio)
|
||||
sampling_rates.append(sampling_rate)
|
||||
@@ -329,7 +328,7 @@ class MMPluginMixin:
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
imglens: Optional[list[int]] = None,
imglens: list[int] | None = None,
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs.

@@ -427,13 +426,13 @@ class BasePlugin(MMPluginMixin):
def process_token_ids(
self,
input_ids: list[int],
labels: Optional[list[int]],
labels: list[int] | None,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
) -> tuple[list[int], list[int] | None]:
r"""Pre-process token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return input_ids, labels
@@ -480,21 +479,39 @@ class ErnieVLPlugin(BasePlugin):
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)

image_processor: BaseImageProcessor = getattr(processor, "image_processor")

merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)

image_idx, video_idx = 0, 0
for message in messages:
content = message["content"]
image_token = self.image_token or "<|image@placeholder|>"
video_token = self.video_token or "<|video@placeholder|>"
image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
1,
)
image_idx += 1
content = content.replace(
IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1
)
while VIDEO_PLACEHOLDER in content:
video_idx += 1
video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1
VIDEO_PLACEHOLDER,
f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>",
1,
)
video_idx += 1
message["content"] = content
return messages

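The ErnieVLPlugin hunk now expands each image placeholder into a run of image tokens whose length is derived from the image grid (`grid_thw.prod() // merge_size**2`) instead of inserting a single token. A self-contained sketch of that expansion step; the grid values, placeholder, and token strings are invented for illustration:

```python
# Illustrative sketch of grid-based placeholder expansion (values are assumptions).
import torch

IMAGE_PLACEHOLDER = "<image>"
image_token = "<|IMAGE_PLACEHOLDER|>"
merge_size = 2
image_grid_thw = [torch.tensor([1, 4, 4])]  # (t, h, w) patch counts for one image

content = f"Describe this {IMAGE_PLACEHOLDER} please."
image_idx = 0
while IMAGE_PLACEHOLDER in content:
    # number of visual tokens = patches per image / (merge_size ** 2)
    image_seqlen = int(image_grid_thw[image_idx].prod()) // (merge_size**2)
    content = content.replace(
        IMAGE_PLACEHOLDER,
        f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
        1,
    )
    image_idx += 1

print(content)  # the placeholder becomes 4 repeated image tokens here
```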
@@ -1288,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
def process_token_ids(
self,
input_ids: list[int],
labels: Optional[list[int]],
labels: list[int] | None,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
) -> tuple[list[int], list[int] | None]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
@@ -2109,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:

def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
audio_token: Optional[str] = None,
image_token: str | None = None,
video_token: str | None = None,
audio_token: str | None = None,
**kwargs,
) -> "BasePlugin":
r"""Get plugin for multimodal inputs."""

@@ -15,7 +15,7 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@@ -33,40 +33,40 @@ class DatasetAttr:
|
||||
formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
|
||||
ranking: bool = False
|
||||
# extra configs
|
||||
subset: Optional[str] = None
|
||||
subset: str | None = None
|
||||
split: str = "train"
|
||||
folder: Optional[str] = None
|
||||
num_samples: Optional[int] = None
|
||||
folder: str | None = None
|
||||
num_samples: int | None = None
|
||||
# common columns
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
videos: Optional[str] = None
|
||||
audios: Optional[str] = None
|
||||
system: str | None = None
|
||||
tools: str | None = None
|
||||
images: str | None = None
|
||||
videos: str | None = None
|
||||
audios: str | None = None
|
||||
# dpo columns
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
chosen: str | None = None
|
||||
rejected: str | None = None
|
||||
kto_tag: str | None = None
|
||||
# alpaca columns
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
prompt: str | None = "instruction"
|
||||
query: str | None = "input"
|
||||
response: str | None = "output"
|
||||
history: str | None = None
|
||||
# sharegpt columns
|
||||
messages: Optional[str] = "conversations"
|
||||
messages: str | None = "conversations"
|
||||
# sharegpt tags
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
assistant_tag: Optional[str] = "gpt"
|
||||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
system_tag: Optional[str] = "system"
|
||||
role_tag: str | None = "from"
|
||||
content_tag: str | None = "value"
|
||||
user_tag: str | None = "human"
|
||||
assistant_tag: str | None = "gpt"
|
||||
observation_tag: str | None = "observation"
|
||||
function_tag: str | None = "function_call"
|
||||
system_tag: str | None = "system"
|
||||
|
||||
def __repr__(self) -> str:
return self.dataset_name

def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
setattr(self, key, obj.get(key, default))

def join(self, attr: dict[str, Any]) -> None:
@@ -90,7 +90,7 @@ class DatasetAttr:
self.set_attr(tag, attr["tags"])


def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
r"""Get the attributes of the datasets."""
if dataset_names is None:
dataset_names = []

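Most hunks in this comparison are the same mechanical change: `Optional[X]` and `Union[X, Y]` annotations are rewritten with the PEP 604 `X | Y` syntax, which is evaluated natively on Python 3.10+ (older interpreters would need `from __future__ import annotations`). A small sketch of the before/after pattern on a hypothetical dataclass field, not one taken from this diff:

```python
from dataclasses import dataclass, field
from typing import Optional  # only needed for the "before" spelling


@dataclass
class ExampleArgs:  # hypothetical class for illustration
    # before: Optional[str] means "str or None"
    subset_old: Optional[str] = None
    # after: the PEP 604 union spells the same type without importing Optional
    subset_new: str | None = None
    # defaults and metadata keep working exactly the same way
    split: str | None = field(default="train", metadata={"help": "Dataset split to load."})


args = ExampleArgs()
assert args.subset_old is None and args.subset_new is None  # identical runtime behavior
```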
@@ -981,7 +981,7 @@ register_template(
replace_eos=True,
replace_jinja_template=True,
template_class=ReasoningTemplate,
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"),
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
)


@@ -15,7 +15,6 @@
|
||||
import os
|
||||
from collections import OrderedDict, defaultdict
|
||||
from enum import Enum, unique
|
||||
from typing import Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
|
||||
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
|
||||
@@ -154,7 +153,7 @@ class RopeScaling(str, Enum):
|
||||
|
||||
def register_model_group(
|
||||
models: dict[str, dict[DownloadSource, str]],
|
||||
template: Optional[str] = None,
|
||||
template: str | None = None,
|
||||
multimodal: bool = False,
|
||||
) -> None:
|
||||
for name, path in models.items():
|
||||
|
||||
@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def get_logger(name: Optional[str] = None) -> "_Logger":
|
||||
def get_logger(name: str | None = None) -> "_Logger":
|
||||
r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
@@ -332,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
if ipv6_enabled:
os.environ.pop("http_proxy", None)
os.environ.pop("HTTP_PROXY", None)
os.environ.pop("https_proxy", None)
os.environ.pop("HTTPS_PROXY", None)
os.environ.pop("all_proxy", None)
os.environ.pop("ALL_PROXY", None)

@@ -16,22 +16,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
|
||||
|
||||
template: Optional[str] = field(
|
||||
template: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||
)
|
||||
eval_dataset: Optional[str] = field(
|
||||
eval_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -39,7 +39,7 @@ class DataArguments:
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
media_dir: Optional[str] = field(
|
||||
media_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
|
||||
)
|
||||
@@ -67,7 +67,7 @@ class DataArguments:
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
interleave_probs: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
|
||||
)
|
||||
@@ -79,15 +79,15 @@ class DataArguments:
|
||||
default=1000,
|
||||
metadata={"help": "The number of examples in one group in pre-processing."},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
preprocessing_num_workers: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the pre-processing."},
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
max_samples: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
||||
)
|
||||
eval_num_beams: Optional[int] = field(
|
||||
eval_num_beams: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
||||
)
|
||||
@@ -103,7 +103,7 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to evaluate on each dataset separately."},
|
||||
)
|
||||
packing: Optional[bool] = field(
|
||||
packing: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
||||
)
|
||||
@@ -111,19 +111,19 @@ class DataArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable sequence packing without cross-attention."},
|
||||
)
|
||||
tool_format: Optional[str] = field(
|
||||
tool_format: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Tool format to use for constructing function calling examples."},
|
||||
)
|
||||
default_system: Optional[str] = field(
|
||||
default_system: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Override the default system message in the template."},
|
||||
)
|
||||
enable_thinking: Optional[bool] = field(
|
||||
enable_thinking: bool | None = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
|
||||
)
|
||||
tokenized_path: Optional[str] = field(
|
||||
tokenized_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from datasets import DownloadMode
|
||||
|
||||
@@ -46,7 +46,7 @@ class EvaluationArguments:
|
||||
default=5,
|
||||
metadata={"help": "Number of examplars for few-shot learning."},
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
save_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save the evaluation results."},
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -40,7 +40,7 @@ class FreezeArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_extra_modules: Optional[str] = field(
|
||||
freeze_extra_modules: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -56,7 +56,7 @@ class FreezeArguments:
|
||||
class LoraArguments:
|
||||
r"""Arguments pertaining to the LoRA training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -66,7 +66,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
lora_alpha: Optional[int] = field(
|
||||
lora_alpha: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
||||
)
|
||||
@@ -88,7 +88,7 @@ class LoraArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
loraplus_lr_ratio: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
||||
)
|
||||
@@ -126,7 +126,7 @@ class LoraArguments:
|
||||
class OFTArguments:
|
||||
r"""Arguments pertaining to the OFT training."""
|
||||
|
||||
additional_target: Optional[str] = field(
|
||||
additional_target: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -220,27 +220,27 @@ class RLHFArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
||||
)
|
||||
ref_model: Optional[str] = field(
|
||||
ref_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
||||
)
|
||||
ref_model_adapters: Optional[str] = field(
|
||||
ref_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reference model."},
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
ref_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reference model."},
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
reward_model: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reward model used for the PPO training."},
|
||||
)
|
||||
reward_model_adapters: Optional[str] = field(
|
||||
reward_model_adapters: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapters of the reward model."},
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
reward_model_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."},
|
||||
)
|
||||
@@ -248,7 +248,7 @@ class RLHFArguments:
|
||||
default="lora",
|
||||
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
||||
)
|
||||
ld_alpha: Optional[float] = field(
|
||||
ld_alpha: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -361,15 +361,15 @@ class BAdamArgument:
|
||||
default="layer",
|
||||
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
||||
)
|
||||
badam_start_block: Optional[int] = field(
|
||||
badam_start_block: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The starting block index for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
||||
badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
|
||||
default="ascending",
|
||||
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
||||
)
|
||||
badam_switch_interval: Optional[int] = field(
|
||||
badam_switch_interval: int | None = field(
|
||||
default=50,
|
||||
metadata={
|
||||
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
||||
@@ -406,15 +406,15 @@ class SwanLabArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
|
||||
)
|
||||
swanlab_project: Optional[str] = field(
|
||||
swanlab_project: str | None = field(
|
||||
default="llamafactory",
|
||||
metadata={"help": "The project name in SwanLab."},
|
||||
)
|
||||
swanlab_workspace: Optional[str] = field(
|
||||
swanlab_workspace: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The workspace name in SwanLab."},
|
||||
)
|
||||
swanlab_run_name: Optional[str] = field(
|
||||
swanlab_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The experiment name in SwanLab."},
|
||||
)
|
||||
@@ -422,19 +422,19 @@ class SwanLabArguments:
|
||||
default="cloud",
|
||||
metadata={"help": "The mode of SwanLab."},
|
||||
)
|
||||
swanlab_api_key: Optional[str] = field(
|
||||
swanlab_api_key: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The API key for SwanLab."},
|
||||
)
|
||||
swanlab_logdir: Optional[str] = field(
|
||||
swanlab_logdir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The log directory for SwanLab."},
|
||||
)
|
||||
swanlab_lark_webhook_url: Optional[str] = field(
|
||||
swanlab_lark_webhook_url: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
|
||||
)
|
||||
swanlab_lark_secret: Optional[str] = field(
|
||||
swanlab_lark_secret: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The Lark(飞书) secret for SwanLab."},
|
||||
)
|
||||
@@ -510,7 +510,7 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to disable the shuffling of the training set."},
|
||||
)
|
||||
early_stopping_steps: Optional[int] = field(
|
||||
early_stopping_steps: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
|
||||
)
|
||||
@@ -530,11 +530,11 @@ class FinetuningArguments(
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: list[str] = split_arg(self.lora_target)
|
||||
self.oft_target: list[str] = split_arg(self.oft_target)
|
||||
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
|
||||
self.additional_target: list[str] | None = split_arg(self.additional_target)
|
||||
self.galore_target: list[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: list[str] = split_arg(self.apollo_target)
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
@@ -17,12 +17,11 @@
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Self
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from transformers.training_args import _convert_str_dict
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
|
||||
from ..extras.logging import get_logger
|
||||
@@ -35,13 +34,13 @@ logger = get_logger(__name__)
|
||||
class BaseModelArguments:
|
||||
r"""Arguments pertaining to the model."""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
|
||||
},
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
adapter_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -50,11 +49,11 @@ class BaseModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
adapter_folder: Optional[str] = field(
|
||||
adapter_folder: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The folder containing the adapter weights to load."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
cache_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
|
||||
)
|
||||
@@ -70,17 +69,17 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
add_tokens: Optional[str] = field(
|
||||
add_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
|
||||
},
|
||||
)
|
||||
add_special_tokens: Optional[str] = field(
|
||||
add_special_tokens: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
new_special_tokens_config: Optional[str] = field(
|
||||
new_special_tokens_config: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -110,7 +109,7 @@ class BaseModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use memory-efficient model loading."},
|
||||
)
|
||||
rope_scaling: Optional[RopeScaling] = field(
|
||||
rope_scaling: RopeScaling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
@@ -122,7 +121,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||
mixture_of_depths: Literal["convert", "load"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||
)
|
||||
@@ -138,7 +137,7 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
moe_aux_loss_coef: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
)
|
||||
@@ -182,15 +181,15 @@ class BaseModelArguments:
|
||||
default="auto",
|
||||
metadata={"help": "Data type for model weights and activations at inference."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
hf_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||
)
|
||||
ms_hub_token: Optional[str] = field(
|
||||
ms_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with ModelScope Hub."},
|
||||
)
|
||||
om_hub_token: Optional[str] = field(
|
||||
om_hub_token: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Modelers Hub."},
|
||||
)
|
||||
@@ -283,7 +282,7 @@ class QuantizationArguments:
|
||||
default=QuantizationMethod.BNB,
|
||||
metadata={"help": "Quantization method to use for on-the-fly quantization."},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
|
||||
)
|
||||
@@ -295,7 +294,7 @@ class QuantizationArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
|
||||
)
|
||||
quantization_device_map: Optional[Literal["auto"]] = field(
|
||||
quantization_device_map: Literal["auto"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
|
||||
)
|
||||
@@ -375,7 +374,7 @@ class ProcessorArguments:
|
||||
class ExportArguments:
|
||||
r"""Arguments pertaining to the model export."""
|
||||
|
||||
export_dir: Optional[str] = field(
|
||||
export_dir: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."},
|
||||
)
|
||||
@@ -387,11 +386,11 @@ class ExportArguments:
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
export_quantization_bit: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the exported model."},
|
||||
)
|
||||
export_quantization_dataset: Optional[str] = field(
|
||||
export_quantization_dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
|
||||
)
|
||||
@@ -407,7 +406,7 @@ class ExportArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
|
||||
)
|
||||
export_hub_model_id: Optional[str] = field(
|
||||
export_hub_model_id: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
|
||||
)
|
||||
@@ -437,7 +436,7 @@ class VllmArguments:
|
||||
default=32,
|
||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||
)
|
||||
vllm_config: Optional[Union[dict, str]] = field(
|
||||
vllm_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -463,7 +462,7 @@ class SGLangArguments:
|
||||
default=-1,
|
||||
metadata={"help": "Tensor parallel size for the SGLang engine."},
|
||||
)
|
||||
sglang_config: Optional[Union[dict, str]] = field(
|
||||
sglang_config: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
||||
)
|
||||
@@ -487,21 +486,21 @@ class KTransformersArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
|
||||
)
|
||||
kt_optimize_rule: Optional[str] = field(
|
||||
kt_optimize_rule: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
||||
},
|
||||
)
|
||||
cpu_infer: Optional[int] = field(
|
||||
cpu_infer: int | None = field(
|
||||
default=32,
|
||||
metadata={"help": "Number Of CPU Cores Used For Computation."},
|
||||
)
|
||||
chunk_size: Optional[int] = field(
|
||||
chunk_size: int | None = field(
|
||||
default=8192,
|
||||
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
|
||||
)
|
||||
mode: Optional[str] = field(
|
||||
mode: str | None = field(
|
||||
default="normal",
|
||||
metadata={"help": "Normal Or Long_Context For Llama Models."},
|
||||
)
|
||||
@@ -539,17 +538,17 @@ class ModelArguments(
|
||||
The class on the most right will be displayed first.
|
||||
"""
|
||||
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
compute_dtype: torch.dtype | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||
)
|
||||
device_map: Optional[Union[str, dict[str, Any]]] = field(
|
||||
device_map: str | dict[str, Any] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||
)
|
||||
model_max_length: Optional[int] = field(
|
||||
model_max_length: int | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -65,7 +65,7 @@ else:
|
||||
_TRAIN_MCA_CLS = tuple()
|
||||
|
||||
|
||||
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
||||
def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
|
||||
r"""Get arguments from the command line or a config file."""
|
||||
if args is not None:
|
||||
return args
|
||||
@@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
|
||||
|
||||
|
||||
def _parse_args(
|
||||
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
|
||||
parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
|
||||
) -> tuple[Any]:
|
||||
args = read_args(args)
|
||||
if isinstance(args, dict):
|
||||
@@ -205,13 +205,13 @@ def _check_extra_dependencies(
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
|
||||
def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
|
||||
parser = HfArgumentParser(_TRAIN_MCA_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
|
||||
@@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N
|
||||
finetuning_args.use_mca = True
|
||||
|
||||
|
||||
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
parser = HfArgumentParser(_INFER_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
parser = HfArgumentParser(_EVAL_ARGS)
|
||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||
|
||||
|
||||
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
|
||||
def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
|
||||
parser = HfArgumentParser(RayArguments)
|
||||
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
|
||||
return ray_args
|
||||
|
||||
|
||||
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
|
||||
if is_env_enabled("USE_MCA"):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
|
||||
else:
|
||||
@@ -473,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||
def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
|
||||
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
|
||||
|
||||
# Setup logging
|
||||
@@ -508,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
|
||||
def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
|
||||
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
|
||||
|
||||
# Setup logging
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -40,7 +40,7 @@ else:
|
||||
class RayArguments:
|
||||
r"""Arguments pertaining to the Ray training."""
|
||||
|
||||
ray_run_name: Optional[str] = field(
|
||||
ray_run_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
|
||||
)
|
||||
@@ -48,7 +48,7 @@ class RayArguments:
|
||||
default="./saves",
|
||||
metadata={"help": "The storage path to save training results to"},
|
||||
)
|
||||
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
|
||||
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
|
||||
)
|
||||
@@ -56,7 +56,7 @@ class RayArguments:
|
||||
default=1,
|
||||
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
|
||||
)
|
||||
resources_per_worker: Union[dict, str] = field(
|
||||
resources_per_worker: dict | str = field(
|
||||
default_factory=lambda: {"GPU": 1},
|
||||
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
|
||||
)
|
||||
@@ -64,7 +64,7 @@ class RayArguments:
|
||||
default="PACK",
|
||||
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
|
||||
)
|
||||
ray_init_kwargs: Optional[Union[dict, str]] = field(
|
||||
ray_init_kwargs: dict | str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict

import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -158,6 +157,7 @@ def load_model(
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"

if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
@@ -205,10 +205,6 @@ def load_model(

if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)

model.eval()
else:
model.train()

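With `torch_dtype="auto"` passed to `from_pretrained`, Transformers loads weights in the dtype recorded in the checkpoint config instead of defaulting to fp32, which is why the explicit per-parameter cast loop above could be dropped. A hedged sketch of the loading call; the model id is an arbitrary example, not one mandated by this diff:

```python
# Sketch: letting Transformers pick the checkpoint's native dtype.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype="auto",  # use the dtype stored in config.json (e.g. bfloat16)
)
print(next(model.parameters()).dtype)  # typically torch.bfloat16 rather than torch.float32
```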
@@ -20,9 +20,10 @@
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -156,16 +156,13 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

# do not cast data type of the model deepspeed zero3 without qlora
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
init_kwargs["torch_dtype"] = model_args.compute_dtype
# fsdp/deepspeed zero3 does not need device map
if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True

if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True

if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder


def patch_model(

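The new condition leans on the fact that `low_cpu_mem_usage` has already been forced off under DeepSpeed ZeRO-3, so the explicit ZeRO-3 check in the device-map branch becomes redundant. A rough sketch of that simplification; the flag plumbing is abstracted and the function is hypothetical, not part of the codebase:

```python
# Sketch of the consolidated device-map decision from the hunk above.
def decide_device_map(low_cpu_mem_usage_flag, zero3_enabled, fsdp_enabled, requested_device_map):
    # low_cpu_mem_usage is disabled under ZeRO-3, so checking it
    # implicitly covers the old "not zero3" condition.
    low_cpu_mem_usage = low_cpu_mem_usage_flag and not zero3_enabled
    if low_cpu_mem_usage and not fsdp_enabled:  # fsdp does not need a device map
        return requested_device_map  # device map requires low_cpu_mem_usage=True
    return None


assert decide_device_map(True, False, False, "auto") == "auto"
assert decide_device_map(True, True, False, "auto") is None  # ZeRO-3: no device map
```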
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
class ComputeAccuracy:
|
||||
r"""Compute reward accuracy and support `batch_eval_metrics`."""
|
||||
|
||||
def _dump(self) -> Optional[dict[str, float]]:
|
||||
def _dump(self) -> dict[str, float] | None:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
@@ -39,7 +39,7 @@ class ComputeAccuracy:
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> dict[str, float] | None:
|
||||
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
|
||||
if not chosen_scores.shape:
|
||||
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
|
||||
|
||||
@@ -84,8 +84,6 @@ def load_reference_model(
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
if not is_trainable:
|
||||
model.v_head = model.v_head.to(torch.float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -19,9 +19,9 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
|
||||
@@ -25,10 +25,11 @@ Including:
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, unique
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -53,9 +53,9 @@ class DistributedStrategy:
|
||||
|
||||
mp_replicate_size: int = 1
|
||||
"""Model parallel replicate size, default to 1."""
|
||||
mp_shard_size: Optional[int] = None
|
||||
mp_shard_size: int | None = None
|
||||
"""Model parallel shard size, default to world_size // mp_replicate_size."""
|
||||
dp_size: Optional[int] = None
|
||||
dp_size: int | None = None
|
||||
"""Data parallel size, default to world_size // cp_size."""
|
||||
cp_size: int = 1
|
||||
"""Context parallel size, default to 1."""
|
||||
@@ -115,7 +115,7 @@ class DistributedInterface:
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: Optional[DistributedConfig] = None) -> None:
|
||||
def __init__(self, config: DistributedConfig | None = None) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
@@ -165,7 +165,7 @@ class DistributedInterface:
|
||||
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||
)
|
||||
|
||||
def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
|
||||
def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
|
||||
"""Get device mesh for specified dimension."""
|
||||
if dim is None:
|
||||
raise ValueError("dim must be specified.")
|
||||
@@ -176,14 +176,14 @@ class DistributedInterface:
|
||||
else:
|
||||
return self.model_device_mesh[dim.value]
|
||||
|
||||
def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
|
||||
def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
|
||||
"""Get process group for specified dimension."""
|
||||
if self.model_device_mesh is None or dim is None:
|
||||
return None
|
||||
else:
|
||||
return self.get_device_mesh(dim).get_group()
|
||||
|
||||
def get_rank(self, dim: Optional[Dim] = None) -> int:
|
||||
def get_rank(self, dim: Dim | None = None) -> int:
|
||||
"""Get parallel rank for specified dimension."""
|
||||
if self.model_device_mesh is None:
|
||||
return 0
|
||||
@@ -192,7 +192,7 @@ class DistributedInterface:
|
||||
else:
|
||||
return self.get_device_mesh(dim).get_local_rank()
|
||||
|
||||
def get_world_size(self, dim: Optional[Dim] = None) -> int:
|
||||
def get_world_size(self, dim: Dim | None = None) -> int:
|
||||
"""Get parallel size for specified dimension."""
|
||||
if self.model_device_mesh is None:
|
||||
return 1
|
||||
@@ -209,7 +209,7 @@ class DistributedInterface:
|
||||
"""Get parallel local world size."""
|
||||
return self._local_world_size
|
||||
|
||||
def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
|
||||
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
|
||||
"""Gather tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
|
||||
@@ -217,7 +217,7 @@ class DistributedInterface:
|
||||
return data
|
||||
|
||||
def all_reduce(
|
||||
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
|
||||
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
|
||||
) -> TensorLike:
|
||||
"""Reduce tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
@@ -225,7 +225,7 @@ class DistributedInterface:
|
||||
else:
|
||||
return data
|
||||
|
||||
def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
|
||||
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
|
||||
"""Broadcast tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import HfArgumentParser
|
||||
@@ -27,7 +27,7 @@ from .sample_args import SampleArguments
|
||||
from .training_args import TrainingArguments
|
||||
|
||||
|
||||
InputArgument = Optional[Union[dict[str, Any], list[str]]]
|
||||
InputArgument = dict[str, Any] | list[str] | None
|
||||
|
||||
|
||||
def validate_args(
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
|
||||
import json
|
||||
from enum import Enum, unique
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class PluginConfig(dict):
|
||||
@@ -33,7 +32,7 @@ class PluginConfig(dict):
|
||||
return self["name"]
|
||||
|
||||
|
||||
PluginArgument = Optional[Union[PluginConfig, dict, str]]
|
||||
PluginArgument = PluginConfig | dict | str | None
|
||||
|
||||
|
||||
@unique
|
||||
@@ -74,7 +73,7 @@ def _convert_str_dict(data: dict) -> dict:
|
||||
return data
|
||||
|
||||
|
||||
def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
|
||||
def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
|
||||
"""Get the plugin configuration from the argument value.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -14,12 +14,11 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
dataset: Optional[str] = field(
|
||||
dataset: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset."},
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
|
||||
@@ -36,15 +35,15 @@ class ModelArguments:
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
)
|
||||
peft_config: Optional[PluginConfig] = field(
|
||||
peft_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "PEFT configuration for the model."},
|
||||
)
|
||||
kernel_config: Optional[PluginConfig] = field(
|
||||
kernel_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Kernel configuration for the model."},
|
||||
)
|
||||
quant_config: Optional[PluginConfig] = field(
|
||||
quant_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Quantization configuration for the model."},
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from .arg_utils import PluginConfig, get_plugin_config
|
||||
@@ -42,7 +41,7 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Use bf16 for training."},
|
||||
)
|
||||
dist_config: Optional[PluginConfig] = field(
|
||||
dist_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Distribution configuration for training."},
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ Get Data Sample:
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from omegaconf import OmegaConf
|
||||
@@ -134,7 +134,7 @@ class DataEngine(Dataset):
|
||||
else:
|
||||
return len(self.data_index)
|
||||
|
||||
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
|
||||
def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
|
||||
"""Get dataset item.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -13,9 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
@@ -70,7 +70,7 @@ class DataIndexPlugin(BasePlugin):
|
||||
"""Plugin for adjusting dataset index."""
|
||||
|
||||
def adjust_data_index(
|
||||
self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
|
||||
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Adjust dataset index by size and weight.
|
||||
|
||||
@@ -95,8 +95,8 @@ class DataSelectorPlugin(BasePlugin):
|
||||
"""Plugin for selecting dataset samples."""
|
||||
|
||||
def select(
|
||||
self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
|
||||
) -> Union[tuple[str, int], list[tuple[str, int]]]:
|
||||
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||
"""Select dataset samples.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -14,7 +14,6 @@


from dataclasses import dataclass
from typing import Union


@dataclass
@@ -32,7 +31,7 @@ class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"

def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
if isinstance(content_data, str):
return content_data.strip()

@@ -47,7 +46,7 @@ class QwenTemplate:

return ""

def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))

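The template above follows the ChatML convention used by Qwen: every message is rendered as `<|im_start|>{role}\n{content}<|im_end|>\n`, with list-form content flattened to plain text first. A rough standalone sketch of that rendering; the message contents and the text-part handling are assumptions for illustration:

```python
# Standalone sketch of ChatML-style rendering like the QwenTemplate above.
MESSAGE_TEMPLATE = "<|im_start|>{role}\n{content}<|im_end|>\n"


def extract_content(content_data: str | list[dict[str, str]]) -> str:
    if isinstance(content_data, str):
        return content_data.strip()
    # list form: keep the text parts only, e.g. [{"type": "text", "text": "hi"}]
    return "".join(part.get("text", "") for part in content_data).strip()


def render_message(message: dict[str, str | list[dict[str, str]]]) -> str:
    return MESSAGE_TEMPLATE.format(role=message["role"], content=extract_content(message.get("content", "")))


print(render_message({"role": "user", "content": "Hello!"}))
# <|im_start|>user
# Hello!
# <|im_end|>
```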
@@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.types import HFModel
|
||||
@@ -38,7 +39,7 @@ class KernelRegistry:
|
||||
self._initialized = True
|
||||
|
||||
def register(
|
||||
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
|
||||
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
|
||||
) -> None:
|
||||
"""Register a kernel implementation.
|
||||
|
||||
@@ -56,7 +57,7 @@ class KernelRegistry:
|
||||
self._registry[kernel_type][device_type] = kernel_impl
|
||||
print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
|
||||
|
||||
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
|
||||
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
|
||||
return self._registry.get(kernel_type, {}).get(device_type)
|
||||
|
||||
|
||||
@@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
|
||||
auto_register: Set to False to disable automatic registration (default: True).
|
||||
"""
|
||||
|
||||
type: Optional[KernelType] = None
|
||||
device: Optional[DeviceType] = None
|
||||
kernel: Optional[Callable] = None
|
||||
type: KernelType | None = None
|
||||
device: DeviceType | None = None
|
||||
kernel: Callable | None = None
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
return discovered_kernels
|
||||
|
||||
|
||||
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
|
||||
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
|
||||
"""Call the MetaKernel's `apply` to perform the replacement.
|
||||
|
||||
Corresponding replacement logic is maintained inside each kernel; the only
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Literal, Optional, TypedDict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
@@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False):
|
||||
"""Plugin name."""
|
||||
freeze_trainable_layers: int
|
||||
"""Freeze trainable layers."""
|
||||
freeze_trainable_modules: Optional[list[str]]
|
||||
freeze_trainable_modules: list[str] | None
|
||||
"""Freeze trainable modules."""
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
|
||||
@@ -38,7 +37,7 @@ class DtypeInterface:
|
||||
_is_fp32_available = True
|
||||
|
||||
@staticmethod
|
||||
def is_available(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_available(precision: str | torch.dtype) -> bool:
|
||||
if precision in DtypeRegistry.HALF_LIST:
|
||||
return DtypeInterface._is_fp16_available
|
||||
elif precision in DtypeRegistry.FLOAT_LIST:
|
||||
@@ -49,19 +48,19 @@ class DtypeInterface:
|
||||
raise RuntimeError(f"Unexpected precision: {precision}")
|
||||
|
||||
@staticmethod
|
||||
def is_fp16(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_fp16(precision: str | torch.dtype) -> bool:
|
||||
return precision in DtypeRegistry.HALF_LIST
|
||||
|
||||
@staticmethod
|
||||
def is_fp32(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_fp32(precision: str | torch.dtype) -> bool:
|
||||
return precision in DtypeRegistry.FLOAT_LIST
|
||||
|
||||
@staticmethod
|
||||
def is_bf16(precision: Union[str, torch.dtype]) -> bool:
|
||||
def is_bf16(precision: str | torch.dtype) -> bool:
|
||||
return precision in DtypeRegistry.BFLOAT_LIST
|
||||
|
||||
@staticmethod
def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
def to_dtype(precision: str | torch.dtype) -> torch.dtype:
if precision in DtypeRegistry.HALF_LIST:
return torch.float16
elif precision in DtypeRegistry.FLOAT_LIST:
@@ -83,7 +82,7 @@ class DtypeInterface:
raise RuntimeError(f"Unexpected precision: {precision}")

@contextmanager
def set_dtype(self, precision: Union[str, torch.dtype]):
def set_dtype(self, precision: str | torch.dtype):
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.to_dtype(precision))
try:

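The `set_dtype` context manager shown above temporarily swaps the global default dtype and restores it afterwards. A minimal sketch of the same pattern; the precision-to-dtype mapping below is a simplified stand-in, not the project's `DtypeRegistry`:

```python
# Sketch of a dtype-switching context manager like DtypeInterface.set_dtype.
from contextlib import contextmanager

import torch

_DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}


@contextmanager
def set_dtype(precision: str | torch.dtype):
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(_DTYPE_MAP[precision] if isinstance(precision, str) else precision)
    try:
        yield
    finally:
        torch.set_default_dtype(original_dtype)  # always restore the previous default


with set_dtype("bf16"):
    x = torch.zeros(2)  # created as bfloat16 inside the context
assert x.dtype == torch.bfloat16 and torch.get_default_dtype() == torch.float32
```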
@@ -81,7 +81,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "_Logger":
+def get_logger(name: str | None = None) -> "_Logger":
"""Return a logger with the specified name. It is not supposed to be accessed externally."""
if name is None:
name = _get_library_name()
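For context on the hunk above, here is a minimal sketch of the library-root-logger pattern it edits (`propagate = False` plus a `get_logger` helper). The handler setup and the `llamafactory` root name are assumptions for illustration; only the signature change comes from the diff.

```python
import logging

_LIBRARY_NAME = "llamafactory"  # assumed root logger name for this sketch


def _configure_library_root_logger() -> None:
    library_root_logger = logging.getLogger(_LIBRARY_NAME)
    if not library_root_logger.handlers:  # configure only once
        library_root_logger.addHandler(logging.StreamHandler())
        library_root_logger.setLevel(logging.INFO)
    # keep library records out of the root logger so applications control their own output
    library_root_logger.propagate = False


def get_logger(name: str | None = None) -> logging.Logger:
    """Return a child of the library root logger; not meant to be used outside the package."""
    _configure_library_root_logger()
    return logging.getLogger(name or _LIBRARY_NAME)


get_logger(f"{_LIBRARY_NAME}.example").info("hello")  # inherits the root logger's handler
```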
@@ -13,7 +13,7 @@
# limitations under the License.


-from typing import Callable, Optional
+from collections.abc import Callable

from . import logging

@@ -29,7 +29,7 @@ class BasePlugin:

_registry: dict[str, Callable] = {}

-def __init__(self, name: Optional[str] = None):
+def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.

Args:
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import TYPE_CHECKING, Literal, TypedDict, Union
-
-from typing_extensions import NotRequired
+from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union


if TYPE_CHECKING:
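This hunk drops the `typing_extensions` dependency for `NotRequired`, which has lived in the standard `typing` module since Python 3.11. A small, hypothetical example of what `NotRequired` expresses in a `TypedDict`:

```python
from typing import Literal, NotRequired, TypedDict  # Python 3.11+


class SampleConfig(TypedDict):  # hypothetical TypedDict, not the one edited in this hunk
    stage: Literal["sft", "dpo"]
    dataset: str
    max_samples: NotRequired[int]  # this key may be omitted entirely


cfg: SampleConfig = {"stage": "sft", "dataset": "alpaca_en_demo"}  # valid without max_samples
```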
@@ -16,7 +16,7 @@ import json
import os
from collections.abc import Generator
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

from transformers.utils import is_torch_npu_available

@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
-self.engine: Optional[BaseEngine] = None
+self.engine: BaseEngine | None = None

if not lazy_init:  # read arguments from command line
super().__init__()
@@ -197,9 +197,9 @@ class WebChatModel(ChatModel):
lang: str,
system: str,
tools: str,
-image: Optional[Any],
-video: Optional[Any],
-audio: Optional[Any],
+image: Any | None,
+video: Any | None,
+audio: Any | None,
max_new_tokens: int,
top_p: float,
temperature: float,
@@ -17,7 +17,7 @@ import os
import signal
from collections import defaultdict
from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any

from psutil import Process
from yaml import safe_dump, safe_load
@@ -71,7 +71,7 @@ def _get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> dict[str, Union[str, dict[str, Any]]]:
+def load_config() -> dict[str, str | dict[str, Any]]:
r"""Load user config if exists."""
try:
with open(_get_config_path(), encoding="utf-8") as f:
@@ -81,7 +81,7 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:


def save_config(
-    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+    lang: str, hub_name: str | None = None, model_name: str | None = None, model_path: str | None = None
) -> None:
r"""Save user config."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
@@ -151,7 +151,7 @@ def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
return {}


-def load_args(config_path: str) -> Optional[dict[str, Any]]:
+def load_args(config_path: str) -> dict[str, Any] | None:
r"""Load the training configuration from config path."""
try:
with open(config_path, encoding="utf-8") as f:
@@ -14,7 +14,7 @@

import json
from collections.abc import Generator
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING

from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"]


-def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
+def can_quantize(checkpoint_path: str | list[str]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
else:
@@ -49,7 +49,7 @@ def save_model(
model_name: str,
model_path: str,
finetuning_type: str,
-checkpoint_path: Union[str, list[str]],
+checkpoint_path: str | list[str],
template: str,
export_size: int,
export_quantization_bit: str,
@@ -14,7 +14,7 @@

import json
import os
-from typing import Any, Optional
+from typing import Any

from transformers.trainer_utils import get_last_checkpoint

@@ -206,7 +206,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
return gr.Dropdown(choices=datasets)


-def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
+def list_output_dirs(model_name: str | None, finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""List all the directories that can resume from.

Inputs: top.model_name, top.finetuning_type, train.current_time
@@ -35,35 +35,40 @@ LOCALES = {
"value": (
"<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub Page</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-"Documentation</a></center></h3>"
+"Documentation</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+"Blog</a></center></h3>"
),
},
"ru": {
"value": (
"<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"страницу GitHub</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-"Документацию</a></center></h3>"
+"Документацию</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+"Блог</a></center></h3>"
),
},
"zh": {
"value": (
"<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub 主页</a> <a href='https://llamafactory.readthedocs.io/zh-cn/latest/' target='_blank'>"
-"官方文档</a></center></h3>"
+"官方文档</a> <a href='https://blog.llamafactory.net/' target='_blank'>"
+"博客</a></center></h3>"
),
},
"ko": {
"value": (
"<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub 페이지</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-"공식 문서</a>를 방문하세요.</center></h3>"
+"공식 문서</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+"블로그</a>를 방문하세요.</center></h3>"
),
},
"ja": {
"value": (
"<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
"GitHub ページ</a> <a href='https://llamafactory.readthedocs.io/en/latest/' target='_blank'>"
-"ドキュメント</a>にアクセスする</center></h3>"
+"ドキュメント</a> <a href='https://blog.llamafactory.net/en/' target='_blank'>"
+"ブログ</a>にアクセスする</center></h3>"
),
},
},
@@ -17,7 +17,7 @@ import os
from collections.abc import Generator
from copy import deepcopy
from subprocess import PIPE, Popen, TimeoutExpired
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

from transformers.utils import is_torch_npu_available

@@ -59,7 +59,7 @@ class Runner:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
-self.trainer: Optional[Popen] = None
+self.trainer: Popen | None = None
self.do_train = True
self.running_data: dict[Component, Any] = None
""" State """
@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
-from typing import Optional

import pytest
from pytest import Config, FixtureRequest, Item, MonkeyPatch
@@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)


-def _get_visible_devices_env() -> Optional[str]:
+def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
-from typing import Optional

import pytest
from pytest import Config, FixtureRequest, Item, MonkeyPatch
@@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)


-def _get_visible_devices_env() -> Optional[str]:
+def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"