Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-09 23:22:52 +08:00)
[model] add gpt oss (#8826)

commit 706b3e5ee7, parent 48615ddb07
.github/workflows/tests.yml (+5)

@@ -72,6 +72,11 @@ jobs:
         run: |
           python -m pip install "transformers==${{ matrix.transformers }}"
 
+      - name: Install transformers to avoid macOS CI errors
+        if: ${{ matrix.os == 'macos-13' }}
+        run: |
+          python -m pip install "transformers<=4.51.3"
+
       - name: Cache files
         id: hf-hub-cache
         uses: actions/cache@v4
README.md

@@ -118,10 +118,14 @@ Choose your path:
 
 ## Changelog
 
-[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model. Please install transformers from **main** branch to use.
+[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
+
+[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
 
 [25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
 
+<details><summary>Full Changelog</summary>
+
 [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
 
 [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.

@@ -130,8 +134,6 @@ Choose your path:
 
 [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
 
-<details><summary>Full Changelog</summary>
-
 [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
 
 [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.

@@ -268,6 +270,7 @@ Choose your path:
 | [GLM-4.1V](https://huggingface.co/zai-org)* | 9B | glm4v |
 | [GLM-4.5](https://huggingface.co/zai-org)* | 106B/355B | glm4_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
 | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
README_zh.md (the Chinese README; entries translated here)

@@ -120,10 +120,14 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 ## Changelog
 
-[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model. Please install transformers from the main branch to use it.
+[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
+
+[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
 
 [25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
 
+<details><summary>Full Changelog</summary>
+
 [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README_zh.md) for usage. Thanks to [@tianshijing](https://github.com/tianshijing)'s PR.
 
 [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.

@@ -132,8 +136,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
 
-<details><summary>Full Changelog</summary>
-
 [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
 
 [25/03/15] We supported the **[SGLang](https://github.com/sgl-project/sglang)** inference backend. Use `infer_backend: sglang` to enable it.

@@ -270,6 +272,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [GLM-4.1V](https://huggingface.co/zai-org)* | 9B | glm4v |
 | [GLM-4.5](https://huggingface.co/zai-org)* | 106B/355B | glm4_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
 | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
examples/train_lora/gpt_lora_sft.yaml (new file, +46)

@@ -0,0 +1,46 @@
+### model
+model_name_or_path: openai/gpt-oss-20b
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_rank: 8
+lora_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: gpt
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/gpt-20b/lora/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+
+### eval
+# eval_dataset: alpaca_en_demo
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
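This recipe runs with the standard entry point (assuming an installed LLaMA-Factory checkout): `llamafactory-cli train examples/train_lora/gpt_lora_sft.yaml`. With `per_device_train_batch_size: 1` and `gradient_accumulation_steps: 8`, the effective batch size is 8 per device, so on a single device the `max_samples: 1000` cap yields 125 optimizer steps per epoch, or 375 over the 3 epochs.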
requirements.txt

@@ -1,6 +1,5 @@
 # core deps
-transformers>=4.49.0,<=4.52.4,!=4.52.0; sys_platform != 'darwin'
-transformers>=4.49.0,<=4.51.3,!=4.52.0; sys_platform == 'darwin'
+transformers>=4.49.0,<=4.55.0,!=4.52.0
 datasets>=2.16.0,<=3.6.0
 accelerate>=1.3.0,<=1.7.0
 peft>=0.14.0,<=0.15.2
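Two things happen here: the upper bound rises to 4.55.0, presumably because that transformers release ships GPT-OSS support, and the darwin-specific pin drops out of the published requirements, surviving only as the macOS CI step added to tests.yml above. Note that the runtime check in misc.py below drops the `!=4.52.0` exclusion while requirements.txt keeps it.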
src/llamafactory/data/template.py

@@ -1063,6 +1063,16 @@ register_template(
 )
 
 
+register_template(
+    name="gpt",
+    format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
+    format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
+    default_system="You are ChatGPT, a large language model trained by OpenAI.",
+    efficient_eos=True,
+)
+
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
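Read end to end, the slots above determine the exact prompt string for a turn. A minimal rendering sketch for one system + user exchange (the render helper is hypothetical glue; the special tokens are copied verbatim from the formatter slots):

import textwrap

def render(system: str, user: str) -> str:
    # Mirrors format_system + format_user: the assistant's reply is generated
    # after "<|start|>assistant" and terminated by "<|end|>" (format_assistant).
    return (
        f"<|start|>system<|message|>{system}<|end|>"
        f"<|start|>user<|message|>{user}<|end|>"
        "<|start|>assistant"
    )

prompt = render("You are ChatGPT, a large language model trained by OpenAI.", "Hi!")
print(textwrap.fill(prompt, width=80))
# <|start|>system<|message|>You are ChatGPT, ...<|end|><|start|>user<|message|>Hi!<|end|><|start|>assistant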
src/llamafactory/extras/constants.py

@@ -945,6 +945,21 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "GPT-OSS-20B-Thinking": {
+            DownloadSource.DEFAULT: "openai/gpt-oss-20b",
+            DownloadSource.MODELSCOPE: "openai/gpt-oss-20b",
+        },
+        "GPT-OSS-120B-Thinking": {
+            DownloadSource.DEFAULT: "openai/gpt-oss-120b",
+            DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
+        },
+    },
+    template="gpt",
+)
+
+
 register_model_group(
     models={
         "Granite-3.0-1B-A400M-Base": {
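Both download sources use identical repo IDs because ModelScope mirrors the weights under the same openai/ namespace. The "-Thinking" suffix on the display names follows the convention used elsewhere in constants.py for reasoning models, and binding template="gpt" selects the chat format registered above automatically when these names are chosen.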
src/llamafactory/extras/misc.py

@@ -18,7 +18,7 @@
 import gc
 import os
 import socket
-from typing import TYPE_CHECKING, Any, Literal, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import torch
 import torch.distributed as dist

@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
+    check_version("transformers>=4.49.0,<=4.55.0")
     check_version("datasets>=2.16.0,<=3.6.0")
     check_version("accelerate>=1.3.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")

@@ -211,9 +211,9 @@ def has_tokenized_data(path: "os.PathLike") -> bool:
     return os.path.isdir(path) and len(os.listdir(path)) > 0
 
 
-def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
+def infer_optim_dtype(model_dtype: Optional["torch.dtype"]) -> "torch.dtype":
     r"""Infer the optimal dtype according to the model_dtype and device compatibility."""
-    if _is_bf16_available and model_dtype == torch.bfloat16:
+    if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):
         return torch.bfloat16
     elif _is_fp16_available:
         return torch.float16
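The widened signature changes behavior for checkpoints whose config carries no explicit torch_dtype: None now resolves to bfloat16 on bf16-capable hardware instead of falling through to float16. A quick sanity check (a sketch assuming an editable install, `pip install -e .`, so `llamafactory` is importable; output depends on the device):

import torch

from llamafactory.extras.misc import infer_optim_dtype

# An unspecified model dtype now picks bf16 on bf16-capable devices.
print(infer_optim_dtype(None))            # torch.bfloat16 on e.g. Ampere+ GPUs
print(infer_optim_dtype(torch.float16))   # unchanged: torch.float16 where fp16 is available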
src/llamafactory/model/loader.py

@@ -156,10 +156,10 @@ def load_model(
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
         else:
-            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
-                load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+            if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                 load_class = AutoModelForImageTextToText
+            elif type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
+                load_class = AutoModelForVision2Seq
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
             elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
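The reorder gives AutoModelForImageTextToText priority over AutoModelForVision2Seq when a config is registered under both mappings, presumably because recent transformers releases register multimodal architectures under the newer ImageTextToText auto class; checking it first loads the preferred class rather than the legacy Vision2Seq path.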
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.4.100
+0.9.4.101