From 706b3e5ee7189dfff7dd502fc2c206ce5a86e5da Mon Sep 17 00:00:00 2001
From: Yaowei Zheng
Date: Wed, 6 Aug 2025 05:56:46 +0800
Subject: [PATCH] [model] add gpt oss (#8826)

---
 .github/workflows/tests.yml           |  5 +++
 README.md                             |  9 ++++--
 README_zh.md                          |  9 ++++--
 examples/train_lora/gpt_lora_sft.yaml | 46 +++++++++++++++++++++++++++
 requirements.txt                      |  3 +-
 src/llamafactory/data/template.py     | 10 ++++++
 src/llamafactory/extras/constants.py  | 15 +++++++++
 src/llamafactory/extras/misc.py       |  8 ++---
 src/llamafactory/model/loader.py      |  6 ++--
 tests/version.txt                     |  2 +-
 10 files changed, 97 insertions(+), 16 deletions(-)
 create mode 100644 examples/train_lora/gpt_lora_sft.yaml

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1dc6230f..6d4ad682 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -72,6 +72,11 @@ jobs:
         run: |
           python -m pip install "transformers==${{ matrix.transformers }}"
 
+      - name: Install older transformers to avoid macOS CI errors
+        if: ${{ matrix.os == 'macos-13' }}
+        run: |
+          python -m pip install "transformers<=4.51.3"
+
       - name: Cache files
         id: hf-hub-cache
         uses: actions/cache@v4
diff --git a/README.md b/README.md
index 2f82b21d..2daeb2e3 100644
--- a/README.md
+++ b/README.md
@@ -118,10 +118,14 @@ Choose your path:
 
 ## Changelog
 
-[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model. Please install transformers from **main** branch to use.
+[25/08/06] We supported fine-tuning the **[GPT-OSS](https://github.com/openai/gpt-oss)** models. See [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) to get started.
+
+[25/07/02] We supported fine-tuning the **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** model.
 
 [25/04/28] We supported fine-tuning the **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** model family.
 
+
Full Changelog
+
 [25/04/21] We supported the **[Muon](https://github.com/KellerJordan/Muon)** optimizer. See [examples](examples/README.md) for usage. Thank [@tianshijing](https://github.com/tianshijing)'s PR.
 
 [25/04/16] We supported fine-tuning the **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** model. See [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) to get started.
 
@@ -130,8 +134,6 @@ Choose your path:
 
 [25/04/06] We supported fine-tuning the **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** model. See [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) to get started.
 
-
Full Changelog
-
 [25/03/31] We supported fine-tuning the **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** model. See [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) to get started.
 
 [25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
@@ -268,6 +270,7 @@ Choose your path:
 | [GLM-4.1V](https://huggingface.co/zai-org)* | 9B | glm4v |
 | [GLM-4.5](https://huggingface.co/zai-org)* | 106B/355B | glm4_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
 | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
diff --git a/README_zh.md b/README_zh.md
index acf6d6e5..83f8fc98 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -120,10 +120,14 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 ## 更新日志
 
-[25/07/02] 我们支持了 **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** 模型的微调。请安装 transformers 的 main 分支版本以使用。
+[25/08/06] 我们支持了 **[GPT-OSS](https://github.com/openai/gpt-oss)** 模型的微调。查看 [PR #8826](https://github.com/hiyouga/LLaMA-Factory/pull/8826) 以使用。
+
+[25/07/02] 我们支持了 **[GLM-4.1V-9B-Thinking](https://github.com/THUDM/GLM-4.1V-Thinking)** 模型的微调。
 
 [25/04/28] 我们支持了 **[Qwen3](https://qwenlm.github.io/blog/qwen3/)** 系列模型的微调。
 
+
展开日志
+
 [25/04/21] 我们支持了 **[Muon](https://github.com/KellerJordan/Muon)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@tianshijing](https://github.com/tianshijing) 的 PR。
 
 [25/04/16] 我们支持了 **[InternVL3](https://huggingface.co/OpenGVLab/InternVL3-8B)** 模型的微调。查看 [PR #7258](https://github.com/hiyouga/LLaMA-Factory/pull/7258) 以使用。
 
@@ -132,8 +136,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 [25/04/06] 我们支持了 **[Llama 4](https://ai.meta.com/blog/llama-4-multimodal-intelligence/)** 模型的微调。查看 [PR #7611](https://github.com/hiyouga/LLaMA-Factory/pull/7611) 以使用。
 
-
展开日志
-
 [25/03/31] 我们支持了 **[Qwen2.5 Omni](https://qwenlm.github.io/blog/qwen2.5-omni/)** 模型的微调。查看 [PR #7537](https://github.com/hiyouga/LLaMA-Factory/pull/7537) 以使用。
 
 [25/03/15] 我们支持了 **[SGLang](https://github.com/sgl-project/sglang)** 推理后端,请使用 `infer_backend: sglang` 启用。
@@ -270,6 +272,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [GLM-4.1V](https://huggingface.co/zai-org)* | 9B | glm4v |
 | [GLM-4.5](https://huggingface.co/zai-org)* | 106B/355B | glm4_moe |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
+| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt |
 | [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Granite 4](https://huggingface.co/ibm-granite) | 7B | granite4 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
diff --git a/examples/train_lora/gpt_lora_sft.yaml b/examples/train_lora/gpt_lora_sft.yaml
new file mode 100644
index 00000000..b07615b1
--- /dev/null
+++ b/examples/train_lora/gpt_lora_sft.yaml
@@ -0,0 +1,46 @@
+### model
+model_name_or_path: openai/gpt-oss-20b
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_rank: 8
+lora_target: all
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: gpt
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/gpt-20b/lora/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-4
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+resume_from_checkpoint: null
+
+### eval
+# eval_dataset: alpaca_en_demo
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/requirements.txt b/requirements.txt
index 4cc5279e..06999de6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
 # core deps
-transformers>=4.49.0,<=4.52.4,!=4.52.0; sys_platform != 'darwin'
-transformers>=4.49.0,<=4.51.3,!=4.52.0; sys_platform == 'darwin'
+transformers>=4.49.0,<=4.55.0,!=4.52.0
 datasets>=2.16.0,<=3.6.0
 accelerate>=1.3.0,<=1.7.0
 peft>=0.14.0,<=0.15.2
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 51fc2b02..70f5e435 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1063,6 +1063,16 @@ register_template(
 )
 
 
+register_template(
+    name="gpt",
+    format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
+    format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
+    default_system="You are ChatGPT, a large language model trained by OpenAI.",
+    efficient_eos=True,
+)
+
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 6b1a8f91..077a3497 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -945,6 +945,21 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "GPT-OSS-20B-Thinking": {
+            DownloadSource.DEFAULT: "openai/gpt-oss-20b",
+            DownloadSource.MODELSCOPE: "openai/gpt-oss-20b",
+        },
+        "GPT-OSS-120B-Thinking": {
"GPT-OSS-120B-Thinking": { + DownloadSource.DEFAULT: "openai/gpt-oss-120b", + DownloadSource.MODELSCOPE: "openai/gpt-oss-120b", + }, + }, + template="gpt", +) + + register_model_group( models={ "Granite-3.0-1B-A400M-Base": { diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 53f742ea..65eedd80 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -18,7 +18,7 @@ import gc import os import socket -from typing import TYPE_CHECKING, Any, Literal, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch import torch.distributed as dist @@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None: def check_dependencies() -> None: r"""Check the version of the required packages.""" - check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0") + check_version("transformers>=4.49.0,<=4.55.0") check_version("datasets>=2.16.0,<=3.6.0") check_version("accelerate>=1.3.0,<=1.7.0") check_version("peft>=0.14.0,<=0.15.2") @@ -211,9 +211,9 @@ def has_tokenized_data(path: "os.PathLike") -> bool: return os.path.isdir(path) and len(os.listdir(path)) > 0 -def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": +def infer_optim_dtype(model_dtype: Optional["torch.dtype"]) -> "torch.dtype": r"""Infer the optimal dtype according to the model_dtype and device compatibility.""" - if _is_bf16_available and model_dtype == torch.bfloat16: + if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None): return torch.bfloat16 elif _is_fp16_available: return torch.float16 diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 7d7ee2e0..8793135f 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -156,10 +156,10 @@ def load_model( if model_args.mixture_of_depths == "load": model = load_mod_pretrained_model(**init_kwargs) else: - if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text - load_class = AutoModelForVision2Seq - elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text + if type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text load_class = AutoModelForImageTextToText + elif type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text + load_class = AutoModelForVision2Seq elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text load_class = AutoModelForSeq2SeqLM elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni diff --git a/tests/version.txt b/tests/version.txt index ff9e9d5e..7436a40d 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.4.100 +0.9.4.101