diff --git a/README.md b/README.md
index 8e2f04f8..8aa27161 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
[](https://pypi.org/project/llamafactory/)
-[](https://scholar.google.com/scholar?cites=12620864006390196564)
+[](https://scholar.google.com/scholar?cites=12620864006390196564)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://twitter.com/llamafactory_ai)
@@ -112,7 +112,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as inference backend. Try `infer_backend: sglang` to accelerate inference.
-[25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.
+[25/03/12] We supported fine-tuning the **[Gemma 3](https://huggingface.co/blog/gemma3)** model.
[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
@@ -873,7 +873,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation
diff --git a/README_zh.md b/README_zh.md
index 0a03fb68..676eb234 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -5,7 +5,7 @@
[](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors)
[](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml)
[](https://pypi.org/project/llamafactory/)
-[](https://scholar.google.com/scholar?cites=12620864006390196564)
+[](https://scholar.google.com/scholar?cites=12620864006390196564)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://twitter.com/llamafactory_ai)
@@ -114,7 +114,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
[25/03/15] 我们支持了 **[SGLang](https://github.com/sgl-project/sglang)** 推理后端,请使用 `infer_backend: sglang` 启用。
-[25/03/12] 我们支持了 **[Gemma-3](https://huggingface.co/blog/gemma3)** 模型的微调。
+[25/03/12] 我们支持了 **[Gemma 3](https://huggingface.co/blog/gemma3)** 模型的微调。
[25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。
@@ -876,7 +876,7 @@ swanlab_run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
-使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Llama 4](https://github.com/meta-llama/llama-models/blob/main/models/llama4/LICENSE) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用
diff --git a/examples/extras/adam_mini/qwen2_full_sft.yaml b/examples/extras/adam_mini/qwen2_full_sft.yaml
index e1c14e07..79df9a73 100644
--- a/examples/extras/adam_mini/qwen2_full_sft.yaml
+++ b/examples/extras/adam_mini/qwen2_full_sft.yaml
@@ -15,6 +15,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/qwen2-1_5b/full/sft
@@ -22,6 +23,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/apollo/llama3_full_sft.yaml b/examples/extras/apollo/llama3_full_sft.yaml
index 9b782b67..d9fb6c20 100644
--- a/examples/extras/apollo/llama3_full_sft.yaml
+++ b/examples/extras/apollo/llama3_full_sft.yaml
@@ -20,6 +20,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/full/sft
@@ -27,6 +28,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/badam/llama3_full_sft.yaml b/examples/extras/badam/llama3_full_sft.yaml
index 57d6729d..7ce33230 100644
--- a/examples/extras/badam/llama3_full_sft.yaml
+++ b/examples/extras/badam/llama3_full_sft.yaml
@@ -20,6 +20,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/full/sft
@@ -27,6 +28,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
index f50fe725..1a8d9743 100644
--- a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
+++ b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml
@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -24,6 +25,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/galore/llama3_full_sft.yaml b/examples/extras/galore/llama3_full_sft.yaml
index 7540da4f..99730932 100644
--- a/examples/extras/galore/llama3_full_sft.yaml
+++ b/examples/extras/galore/llama3_full_sft.yaml
@@ -19,6 +19,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/full/sft
@@ -26,6 +27,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/llama_pro/llama3_freeze_sft.yaml b/examples/extras/llama_pro/llama3_freeze_sft.yaml
index 6d51ceee..6c5efb8b 100644
--- a/examples/extras/llama_pro/llama3_freeze_sft.yaml
+++ b/examples/extras/llama_pro/llama3_freeze_sft.yaml
@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b-pro/freeze/sft
@@ -24,6 +25,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/loraplus/llama3_lora_sft.yaml b/examples/extras/loraplus/llama3_lora_sft.yaml
index cbecf632..574b4870 100644
--- a/examples/extras/loraplus/llama3_lora_sft.yaml
+++ b/examples/extras/loraplus/llama3_lora_sft.yaml
@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -24,6 +25,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/mod/llama3_full_sft.yaml b/examples/extras/mod/llama3_full_sft.yaml
index 4799a826..ed784e74 100644
--- a/examples/extras/mod/llama3_full_sft.yaml
+++ b/examples/extras/mod/llama3_full_sft.yaml
@@ -15,6 +15,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b-mod/full/sft
@@ -22,6 +23,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/extras/nlg_eval/llama3_lora_predict.yaml b/examples/extras/nlg_eval/llama3_lora_predict.yaml
index 2d0dfad3..be51c2e4 100644
--- a/examples/extras/nlg_eval/llama3_lora_predict.yaml
+++ b/examples/extras/nlg_eval/llama3_lora_predict.yaml
@@ -18,10 +18,12 @@ cutoff_len: 2048
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/predict
overwrite_output_dir: true
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### eval
per_device_eval_batch_size: 1
diff --git a/examples/extras/pissa/llama3_lora_sft.yaml b/examples/extras/pissa/llama3_lora_sft.yaml
index 5827f6a5..1668343b 100644
--- a/examples/extras/pissa/llama3_lora_sft.yaml
+++ b/examples/extras/pissa/llama3_lora_sft.yaml
@@ -19,6 +19,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -26,6 +27,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_full/llama3_full_sft.yaml b/examples/train_full/llama3_full_sft.yaml
index 19d6df42..fb7066a7 100644
--- a/examples/train_full/llama3_full_sft.yaml
+++ b/examples/train_full/llama3_full_sft.yaml
@@ -24,6 +24,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml
index 41e9f49e..a2fb9bc4 100644
--- a/examples/train_full/qwen2vl_full_sft.yaml
+++ b/examples/train_full/qwen2vl_full_sft.yaml
@@ -8,10 +8,10 @@ trust_remote_code: true
stage: sft
do_train: true
finetuning_type: full
-freeze_vision_tower: true # choices: [true, false]
-freeze_multi_modal_projector: true # choices: [true, false]
-freeze_language_model: false # choices: [true, false]
-deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
+freeze_vision_tower: true
+freeze_multi_modal_projector: true
+freeze_language_model: false
+deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: mllm_demo,identity,alpaca_en_demo
@@ -29,6 +29,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_dpo.yaml b/examples/train_lora/llama3_lora_dpo.yaml
index 1b890ab2..fd8c042c 100644
--- a/examples/train_lora/llama3_lora_dpo.yaml
+++ b/examples/train_lora/llama3_lora_dpo.yaml
@@ -27,6 +27,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_kto.yaml b/examples/train_lora/llama3_lora_kto.yaml
index 980381ab..113b9129 100644
--- a/examples/train_lora/llama3_lora_kto.yaml
+++ b/examples/train_lora/llama3_lora_kto.yaml
@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/kto
@@ -24,6 +25,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_ppo.yaml b/examples/train_lora/llama3_lora_ppo.yaml
index 4efc3009..87944819 100644
--- a/examples/train_lora/llama3_lora_ppo.yaml
+++ b/examples/train_lora/llama3_lora_ppo.yaml
@@ -17,6 +17,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/ppo
@@ -24,6 +25,7 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_pretrain.yaml b/examples/train_lora/llama3_lora_pretrain.yaml
index 82e0d58a..3c851d70 100644
--- a/examples/train_lora/llama3_lora_pretrain.yaml
+++ b/examples/train_lora/llama3_lora_pretrain.yaml
@@ -24,6 +24,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_reward.yaml b/examples/train_lora/llama3_lora_reward.yaml
index e71a99b8..48230b55 100644
--- a/examples/train_lora/llama3_lora_reward.yaml
+++ b/examples/train_lora/llama3_lora_reward.yaml
@@ -25,6 +25,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_sft.yaml b/examples/train_lora/llama3_lora_sft.yaml
index fe889208..157d6610 100644
--- a/examples/train_lora/llama3_lora_sft.yaml
+++ b/examples/train_lora/llama3_lora_sft.yaml
@@ -25,6 +25,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_sft_ds3.yaml b/examples/train_lora/llama3_lora_sft_ds3.yaml
index b35f5466..e20b3517 100644
--- a/examples/train_lora/llama3_lora_sft_ds3.yaml
+++ b/examples/train_lora/llama3_lora_sft_ds3.yaml
@@ -26,6 +26,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/llama3_lora_sft_ray.yaml b/examples/train_lora/llama3_lora_sft_ray.yaml
index d30e986b..e7e4b390 100644
--- a/examples/train_lora/llama3_lora_sft_ray.yaml
+++ b/examples/train_lora/llama3_lora_sft_ray.yaml
@@ -26,6 +26,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### ray
ray_run_name: llama3_8b_sft_lora
diff --git a/examples/train_lora/llama4_lora_sft.yaml b/examples/train_lora/llama4_lora_sft_ds3.yaml
similarity index 90%
rename from examples/train_lora/llama4_lora_sft.yaml
rename to examples/train_lora/llama4_lora_sft_ds3.yaml
index f9123091..6c5bb7bb 100644
--- a/examples/train_lora/llama4_lora_sft.yaml
+++ b/examples/train_lora/llama4_lora_sft_ds3.yaml
@@ -28,10 +28,11 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
-gradient_accumulation_steps: 8
+gradient_accumulation_steps: 2
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
diff --git a/examples/train_lora/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml
index 116c2a42..63cdcaea 100644
--- a/examples/train_lora/llava1_5_lora_sft.yaml
+++ b/examples/train_lora/llava1_5_lora_sft.yaml
@@ -25,6 +25,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/qwen2vl_lora_dpo.yaml b/examples/train_lora/qwen2vl_lora_dpo.yaml
index 148c4ec2..3c990b42 100644
--- a/examples/train_lora/qwen2vl_lora_dpo.yaml
+++ b/examples/train_lora/qwen2vl_lora_dpo.yaml
@@ -29,6 +29,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_lora/qwen2vl_lora_sft.yaml b/examples/train_lora/qwen2vl_lora_sft.yaml
index c57b78e4..54ff9842 100644
--- a/examples/train_lora/qwen2vl_lora_sft.yaml
+++ b/examples/train_lora/qwen2vl_lora_sft.yaml
@@ -27,6 +27,7 @@ save_steps: 500
plot_loss: true
overwrite_output_dir: true
save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_qlora/llama3_lora_sft_aqlm.yaml b/examples/train_qlora/llama3_lora_sft_aqlm.yaml
index c035a504..a7d44c7e 100644
--- a/examples/train_qlora/llama3_lora_sft_aqlm.yaml
+++ b/examples/train_qlora/llama3_lora_sft_aqlm.yaml
@@ -16,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_qlora/llama3_lora_sft_awq.yaml b/examples/train_qlora/llama3_lora_sft_awq.yaml
index a9a970f0..861edfde 100644
--- a/examples/train_qlora/llama3_lora_sft_awq.yaml
+++ b/examples/train_qlora/llama3_lora_sft_awq.yaml
@@ -16,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_qlora/llama3_lora_sft_bnb_npu.yaml b/examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
index 5b768a98..d68ce665 100644
--- a/examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
+++ b/examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
@@ -1,7 +1,7 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
quantization_bit: 4
-quantization_method: bitsandbytes
+quantization_method: bnb
double_quantization: false
trust_remote_code: true
@@ -19,6 +19,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -26,6 +27,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_qlora/llama3_lora_sft_gptq.yaml b/examples/train_qlora/llama3_lora_sft_gptq.yaml
index 6c45e845..729d8628 100644
--- a/examples/train_qlora/llama3_lora_sft_gptq.yaml
+++ b/examples/train_qlora/llama3_lora_sft_gptq.yaml
@@ -16,6 +16,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -23,6 +24,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/examples/train_qlora/llama3_lora_sft_otfq.yaml b/examples/train_qlora/llama3_lora_sft_otfq.yaml
index cfe930e4..1a157afe 100644
--- a/examples/train_qlora/llama3_lora_sft_otfq.yaml
+++ b/examples/train_qlora/llama3_lora_sft_otfq.yaml
@@ -1,7 +1,7 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-quantization_bit: 4
-quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
+quantization_bit: 4 # choices: [8 (bnb/hqq/eetq), 4 (bnb/hqq), 3 (hqq), 2 (hqq)]
+quantization_method: bnb # choices: [bnb, hqq, eetq]
trust_remote_code: true
### method
@@ -18,6 +18,7 @@ cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
+dataloader_num_workers: 4
### output
output_dir: saves/llama3-8b/lora/sft
@@ -25,6 +26,8 @@ logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
+save_only_model: false
+report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
### train
per_device_train_batch_size: 1
diff --git a/scripts/eval_bleu_rouge.py b/scripts/eval_bleu_rouge.py
index 735209ce..22e370bc 100644
--- a/scripts/eval_bleu_rouge.py
+++ b/scripts/eval_bleu_rouge.py
@@ -21,9 +21,9 @@ from datasets import load_dataset
try:
- import jieba
- from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
- from rouge_chinese import Rouge
+ import jieba # type: ignore
+ from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu # type: ignore
+ from rouge_chinese import Rouge # type: ignore
jieba.setLogLevel(logging.CRITICAL)
jieba.initialize()
@@ -52,6 +52,7 @@ def compute_metrics(sample):
metric_result = {}
for k, v in result.items():
metric_result[k] = round(v["f"] * 100, 4)
+
metric_result["bleu-4"] = round(bleu_score * 100, 4)
return metric_result
diff --git a/scripts/qwen_omni_merge.py b/scripts/qwen_omni_merge.py
index 3c75d77a..449b17b0 100644
--- a/scripts/qwen_omni_merge.py
+++ b/scripts/qwen_omni_merge.py
@@ -1,7 +1,4 @@
-# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
-#
-# This code is based on the HuggingFace's PEFT library.
-# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
import shutil
import fire
from peft import PeftModel
-from transformers import AutoModel, AutoProcessor, AutoTokenizer, Qwen2_5OmniThinkerForConditionalGeneration
+from transformers import AutoModel, AutoProcessor, Qwen2_5OmniThinkerForConditionalGeneration # type: ignore
def merge_lora(
@@ -41,20 +39,14 @@ def merge_lora(
save_path (str): Directory where the merged model and configurations will be saved.
"""
# 1. Load the original model, tokenizer, and processor
- model = AutoModel.from_pretrained(base_model_path)
- tokenizer = AutoTokenizer.from_pretrained(base_model_path)
-
- try:
- processor = AutoProcessor.from_pretrained(base_model_path)
- except Exception:
- print("Processor configuration not found, skipping processor load.")
- processor = None
-
- print("Successfully loaded the original model, tokenizer, and processor (if available).")
+ model = AutoModel.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
+ processor = AutoProcessor.from_pretrained(base_model_path)
+ print("Successfully loaded the original model and tokenizer.")
# 2. Extract the submodule to be merged (e.g., model.thinker)
if not hasattr(model, submodule_name):
raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")
+
base_submodule = getattr(model, submodule_name)
print(f"Successfully extracted submodule: {submodule_name}.")
@@ -71,11 +63,8 @@ def merge_lora(
# 6. Save the final merged model along with the tokenizer and processor configuration
model.save_pretrained(save_path)
- tokenizer.save_pretrained(save_path)
- if processor is not None:
- processor.save_pretrained(save_path)
-
- print(f"Merged model and configuration saved to {save_path}.")
+ processor.save_pretrained(save_path)
+ print(f"Merged model and tokenizer saved to {save_path}.")
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
@@ -89,7 +78,7 @@ def merge_lora(
def save_full_model(
saved_thinker_path: str,
base_model_path: str,
- save_path: str,
+ save_path: str = "./merged_model_checkpoint",
extra_file: str = "spk_dict.pt",
):
"""Load the saved thinker module and the original model, replace the thinker in the original model.
@@ -99,26 +88,23 @@ def save_full_model(
Args:
saved_thinker_path (str): Path to the saved thinker weights.
base_model_path (str): Directory path of the original model.
- save_path (str): Directory where the final complete model will be saved.
+ save_path (str): Directory where the full model and configurations will be saved.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
"""
- # Load the thinker module
- thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(saved_thinker_path, device_map="cpu")
- # Load the original model
- base_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")
- # Replace the thinker module in the original model
+ # 1. Load the saved thinker module and the original model
+ thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+ saved_thinker_path, torch_dtype="auto", device_map="cpu"
+ )
+ base_model = AutoModel.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
base_model.thinker = thinker
- # Load the processor and tokenizer
- processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
- tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
-
- # Save the complete model along with its configurations
+ # 2. Save the complete model along with its tokenizer and processor configuration
+ processor = AutoProcessor.from_pretrained(base_model_path)
base_model.save_pretrained(save_path)
- tokenizer.save_pretrained(save_path)
processor.save_pretrained(save_path)
- print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")
+ print(f"Merged model and tokenizer saved to {save_path}.")
+ # 3. Copy the extra file from the base model directory to the save_path
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
if os.path.exists(source_file):
diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index ce17adfb..53391eec 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -20,7 +20,7 @@ from transformers import Seq2SeqTrainingArguments
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.extras.misc import check_version, get_device_count
+from llamafactory.extras.misc import get_device_count
from llamafactory.extras.packages import is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -56,7 +56,6 @@ def vllm_infer(
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
- check_version("vllm>=0.4.3,<=0.8.2")
if pipeline_parallel_size > get_device_count():
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
diff --git a/setup.py b/setup.py
index 6cfe66a0..e2d92fd1 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@ extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
- "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
+ "deepspeed": ["deepspeed>=0.10.0,<=0.16.5"],
"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
@@ -53,7 +53,7 @@ extra_require = {
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
- "vllm": ["vllm>=0.4.3,<=0.8.2"],
+ "vllm": ["vllm>=0.4.3,<=0.8.3"],
"sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index b79d22b4..8b64f9a0 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -24,7 +24,6 @@ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
-from ..extras.misc import get_current_device
from ..extras.packages import is_pillow_available
@@ -65,30 +64,19 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""
_, seq_len = attention_mask_with_indices.size()
-
- # Move to compute device if the source is CPU.
- source_device = attention_mask_with_indices.device
- compute_device = get_current_device() if source_device.type == "cpu" else source_device
- if compute_device != source_device:
- attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
-
min_dtype = torch.finfo(dtype).min
- zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+ zero_tensor = torch.tensor(0, dtype=dtype)
# Create a non-padding mask.
- non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+ non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
# Create indices for comparison.
indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
# Create a lower triangular mask.
- tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
- attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
+ tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
+ attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
-
- # Move back to original device if needed.
- if compute_device != source_device:
- attention_mask_4d = attention_mask_4d.to(source_device)
return attention_mask_4d
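
For context, a minimal sketch of how the simplified, device-agnostic helper behaves on one packed row (hypothetical index values; assumes the two-argument signature shown in the hunk header): tokens sharing a non-zero sequence index attend to each other causally, padding is fully masked, and the result stays on the input tensor's device.

```python
import torch

from llamafactory.data.collator import prepare_4d_attention_mask

# One packed sample: two sequences (indices 1 and 2) followed by one padding token (0).
mask_with_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)

print(mask_4d.shape)               # torch.Size([1, 1, 6, 6])
print((mask_4d == 0).int()[0, 0])  # block-diagonal, lower-triangular pattern of allowed positions
```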
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index c5133539..887dc08f 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -493,8 +493,8 @@ class Llama4Plugin(BasePlugin):
messages = deepcopy(messages)
for message in messages:
content = message["content"]
- placeholder_count = content.count(IMAGE_PLACEHOLDER)
if self.expand_mm_tokens:
+ placeholder_count = content.count(IMAGE_PLACEHOLDER)
prompt_splits = content.split(IMAGE_PLACEHOLDER)
new_content = []
for local_image_index, split_part in enumerate(prompt_splits):
@@ -507,6 +507,8 @@ class Llama4Plugin(BasePlugin):
new_content.append(tokens_for_this_image)
content = "".join(new_content)
+ else:
+ content = content.replace(IMAGE_PLACEHOLDER, self.image_token)
message["content"] = content
diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py
index 2e38b8a5..26e50d93 100644
--- a/src/llamafactory/data/processor/supervised.py
+++ b/src/llamafactory/data/processor/supervised.py
@@ -164,28 +164,28 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
- packed_input_ids, packed_attention_masks, packed_labels = [], [], []
- packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
+ packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
+ packed_images, packed_videos, packed_audios = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
+ packed_position_ids += list(range(len(batch_input_ids[index]))) # NOTE: pad_to_multiple_of ignores this
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
- packed_position_ids += list(range(len(batch_input_ids[index])))
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn drops attn mask
pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
+ packed_position_ids += [0] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
- packed_position_ids += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
@@ -194,10 +194,10 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
+ model_inputs["position_ids"].append(packed_position_ids)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
- model_inputs["position_ids"].append(packed_position_ids or None)
return model_inputs
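
A toy walk-through of the packing change above (hypothetical lengths, not repository data): position ids now restart at 0 for every packed sequence and are zero-padded along with the rest of the row, whether or not `neat_packing` is enabled.

```python
# Two sequences of lengths 3 and 2 packed into one row of cutoff_len + 1 = 7 tokens.
batch_input_ids = [[11, 12, 13], [21, 22]]
cutoff_len = 6

packed_position_ids = []
for ids in batch_input_ids:
    packed_position_ids += list(range(len(ids)))  # positions restart for each packed sequence

pad_length = cutoff_len - sum(len(ids) for ids in batch_input_ids) + 1
packed_position_ids += [0] * pad_length           # padding positions are zero

print(packed_position_ids)  # [0, 1, 2, 0, 1, 0, 0]
```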
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 8732685e..6a84cbf9 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1370,7 +1370,7 @@ register_template(
slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
- default_system="You are a helpful assistant.",
+ default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
stop_words=["<|im_end|>"],
)
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index a1829145..42a679c1 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -14,7 +14,7 @@
import os
from collections import OrderedDict, defaultdict
-from enum import Enum
+from enum import Enum, unique
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
@@ -115,6 +115,19 @@ class DownloadSource(str, Enum):
OPENMIND = "om"
+@unique
+class QuantizationMethod(str, Enum):
+ r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
+
+ BNB = "bnb"
+ GPTQ = "gptq"
+ AWQ = "awq"
+ AQLM = "aqlm"
+ QUANTO = "quanto"
+ EETQ = "eetq"
+ HQQ = "hqq"
+
+
class RopeScaling(str, Enum):
LINEAR = "linear"
DYNAMIC = "dynamic"
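
Because `QuantizationMethod` subclasses `str`, its members compare equal to the raw strings used in configs and the web UI, which is what lets the call sites in `quantization.py` drop the `.value` suffix; a small sanity check:

```python
from llamafactory.extras.constants import QuantizationMethod

assert QuantizationMethod.BNB == "bnb"            # str-enum members equal their string values
assert QuantizationMethod("hqq") is QuantizationMethod.HQQ
print(list(QuantizationMethod))                   # bnb, gptq, awq, aqlm, quanto, eetq, hqq
```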
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 3a66b2c0..96dfb391 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -160,5 +160,11 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
+ if self.neat_packing:
+ self.packing = True
+
+ if self.packing:
+ self.cutoff_len -= 1 # avoid pad_to_multiple_of, needs improvement
+
def to_dict(self) -> dict[str, Any]:
return asdict(self)
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index a9f61289..f86291b0 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -23,7 +23,7 @@ import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
-from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
@dataclass
@@ -184,8 +184,8 @@ class BaseModelArguments:
class QuantizationArguments:
r"""Arguments pertaining to the quantization method."""
- quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
- default="bitsandbytes",
+ quantization_method: QuantizationMethod = field(
+ default=QuantizationMethod.BNB,
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index f23ba89f..1b9d7153 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -135,7 +135,7 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
- check_version("vllm>=0.4.3,<=0.8.2")
+ check_version("vllm>=0.4.3,<=0.8.3")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4")
@@ -285,10 +285,6 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
- if data_args.neat_packing and not data_args.packing:
- logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
- data_args.packing = True
-
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index e370abfc..598a1a05 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -97,12 +97,13 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, tokenizer, model_args)
except Exception as e:
- logger.debug(f"Processor was not found: {e}.")
+ logger.debug(f"Failed to load processor: {e}.")
processor = None
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if processor is not None and "Processor" not in processor.__class__.__name__:
+ logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
processor = None
return {"tokenizer": tokenizer, "processor": processor}
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index e33fed86..ffbf5825 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -18,7 +18,6 @@
import os
import random
-from enum import Enum, unique
from typing import TYPE_CHECKING, Any
import torch
@@ -28,7 +27,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ...extras import logging
-from ...extras.constants import FILEEXT2TYPE
+from ...extras.constants import FILEEXT2TYPE, QuantizationMethod
from ...extras.misc import check_version, get_current_device
@@ -41,19 +40,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
-@unique
-class QuantizationMethod(str, Enum):
- r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
-
- BITS_AND_BYTES = "bitsandbytes"
- GPTQ = "gptq"
- AWQ = "awq"
- AQLM = "aqlm"
- QUANTO = "quanto"
- EETQ = "eetq"
- HQQ = "hqq"
-
-
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
if os.path.isfile(model_args.export_quantization_dataset):
@@ -145,7 +131,7 @@ def configure_quantization(
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
elif model_args.quantization_bit is not None: # on-the-fly
- if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+ if model_args.quantization_method == QuantizationMethod.BNB:
if model_args.quantization_bit == 8:
check_version("bitsandbytes>=0.37.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
@@ -173,7 +159,7 @@ def configure_quantization(
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
- elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+ elif model_args.quantization_method == QuantizationMethod.HQQ:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -185,7 +171,7 @@ def configure_quantization(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
- elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+ elif model_args.quantization_method == QuantizationMethod.EETQ:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 422a702e..d06c4c66 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -91,7 +91,13 @@ def run_dpo(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
+ keys = ["loss", "rewards/accuracies"]
+ if isinstance(dataset_module["eval_dataset"], dict):
+ keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+ else:
+ keys += ["eval_loss"]
+
+ plot_loss(training_args.output_dir, keys=keys)
# Evaluation
if training_args.do_eval:
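
The same key-building pattern is repeated in the KTO, PT, RM, and SFT workflows below; a rough sketch of the resulting plot keys, using a hypothetical two-dataset eval split:

```python
def build_dpo_plot_keys(eval_dataset):
    # Mirrors the branch above: per-dataset loss keys when eval_dataset is a dict.
    keys = ["loss", "rewards/accuracies"]
    if isinstance(eval_dataset, dict):
        keys += [f"eval_{key}_loss" for key in eval_dataset.keys()]
    else:
        keys += ["eval_loss"]
    return keys

print(build_dpo_plot_keys({"alpaca_en_demo": None, "mmlu": None}))
# ['loss', 'rewards/accuracies', 'eval_alpaca_en_demo_loss', 'eval_mmlu_loss']
```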
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index 45f82671..3a9a6a76 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -82,7 +82,13 @@ def run_kto(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])
+ keys = ["loss", "rewards/chosen"]
+ if isinstance(dataset_module["eval_dataset"], dict):
+ keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+ else:
+ keys += ["eval_loss"]
+
+ plot_loss(training_args.output_dir, keys=keys)
# Evaluation
if training_args.do_eval:
diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py
index 3c04595a..da5583e2 100644
--- a/src/llamafactory/train/pt/workflow.py
+++ b/src/llamafactory/train/pt/workflow.py
@@ -66,7 +66,13 @@ def run_pt(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+ keys = ["loss"]
+ if isinstance(dataset_module["eval_dataset"], dict):
+ keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+ else:
+ keys += ["eval_loss"]
+
+ plot_loss(training_args.output_dir, keys=keys)
# Evaluation
if training_args.do_eval:
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index cb693187..81f555dd 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -74,7 +74,15 @@ def run_rm(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
+ keys = ["loss"]
+ if isinstance(dataset_module["eval_dataset"], dict):
+ keys += sum(
+ [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
+ )
+ else:
+ keys += ["eval_loss", "eval_accuracy"]
+
+ plot_loss(training_args.output_dir, keys=keys)
# Evaluation
if training_args.do_eval:
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 37006707..b80cdfc4 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -110,7 +110,15 @@ def run_sft(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
+ keys = ["loss"]
+ if isinstance(dataset_module["eval_dataset"], dict):
+ keys += sum(
+ [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
+ )
+ else:
+ keys += ["eval_loss", "eval_accuracy"]
+
+ plot_loss(training_args.output_dir, keys=keys)
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index f3616455..d6df1746 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -42,7 +42,7 @@ def create_top() -> dict[str, "Component"]:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
- quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
+ quantization_method = gr.Dropdown(choices=["bnb", "hqq", "eetq"], value="bnb")
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")
diff --git a/tests/data/processor/test_feedback.py b/tests/data/processor/test_feedback.py
index 27c3676c..355e7fe0 100644
--- a/tests/data/processor/test_feedback.py
+++ b/tests/data/processor/test_feedback.py
@@ -25,10 +25,10 @@ from llamafactory.train.test_utils import load_dataset_module
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "kto",
"do_train": True,
"finetuning_type": "full",
@@ -45,7 +45,7 @@ TRAIN_ARGS = {
@pytest.mark.parametrize("num_samples", [16])
def test_feedback_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
original_data = load_dataset(DEMO_DATA, name="kto_en_demo", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
diff --git a/tests/data/processor/test_pairwise.py b/tests/data/processor/test_pairwise.py
index 569a55ab..1040ba82 100644
--- a/tests/data/processor/test_pairwise.py
+++ b/tests/data/processor/test_pairwise.py
@@ -25,10 +25,10 @@ from llamafactory.train.test_utils import load_dataset_module
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "rm",
"do_train": True,
"finetuning_type": "full",
@@ -54,7 +54,7 @@ def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str
@pytest.mark.parametrize("num_samples", [16])
def test_pairwise_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
original_data = load_dataset(DEMO_DATA, name="dpo_en_demo", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
diff --git a/tests/data/processor/test_supervised.py b/tests/data/processor/test_supervised.py
index e2171721..6eaa34d3 100644
--- a/tests/data/processor/test_supervised.py
+++ b/tests/data/processor/test_supervised.py
@@ -25,12 +25,12 @@ from llamafactory.train.test_utils import load_dataset_module
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
@@ -45,7 +45,7 @@ TRAIN_ARGS = {
@pytest.mark.parametrize("num_samples", [16])
def test_supervised_single_turn(num_samples: int):
train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
original_data = load_dataset(TINY_DATA, split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
@@ -66,7 +66,7 @@ def test_supervised_multi_turn(num_samples: int):
train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
"train_dataset"
]
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
@@ -79,7 +79,7 @@ def test_supervised_train_on_prompt(num_samples: int):
train_dataset = load_dataset_module(
dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", train_on_prompt=True, **TRAIN_ARGS
)["train_dataset"]
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
@@ -93,7 +93,7 @@ def test_supervised_mask_history(num_samples: int):
train_dataset = load_dataset_module(
dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", mask_history=True, **TRAIN_ARGS
)["train_dataset"]
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
diff --git a/tests/data/processor/test_unsupervised.py b/tests/data/processor/test_unsupervised.py
index 4b0a97d3..947b2e39 100644
--- a/tests/data/processor/test_unsupervised.py
+++ b/tests/data/processor/test_unsupervised.py
@@ -24,12 +24,12 @@ from llamafactory.train.test_utils import load_dataset_module
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "ppo",
"do_train": True,
"finetuning_type": "full",
@@ -48,7 +48,7 @@ TRAIN_ARGS = {
@pytest.mark.parametrize("num_samples", [16])
def test_unsupervised_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py
index 23a045ae..a263c0f8 100644
--- a/tests/data/test_collator.py
+++ b/tests/data/test_collator.py
@@ -24,11 +24,11 @@ from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
def test_base_collator():
- model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
+ model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA3, "template": "default"})
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
data_collator = MultiModalDataCollatorForSeq2Seq(
diff --git a/tests/data/test_loader.py b/tests/data/test_loader.py
index fc2d2a91..b45bdaad 100644
--- a/tests/data/test_loader.py
+++ b/tests/data/test_loader.py
@@ -19,12 +19,12 @@ from llamafactory.train.test_utils import load_dataset_module
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index 787a685c..5dfbe9e4 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -20,7 +20,6 @@ import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -35,7 +34,8 @@ if TYPE_CHECKING:
HF_TOKEN = os.getenv("HF_TOKEN")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
MM_MESSAGES = [
{"role": "user", "content": "What is in this image?"},
@@ -130,13 +130,13 @@ def _check_plugin(
def test_base_plugin():
- tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
+ tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
base_plugin = get_mm_plugin(name="base")
check_inputs = {"plugin": base_plugin, **tokenizer_module}
_check_plugin(**check_inputs)
-@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_gemma3_plugin():
image_seqlen = 256
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
@@ -157,6 +157,27 @@ def test_gemma3_plugin():
_check_plugin(**check_inputs)
+@pytest.mark.xfail(reason="Unknown error.")
+def test_llama4_plugin():
+ tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
+ processor = tokenizer_module["processor"]
+ llama4_plugin = get_mm_plugin(name="llama4", image_token="<|image|>")
+ check_inputs = {"plugin": llama4_plugin, **tokenizer_module}
+ mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
+ image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
+ num_patches_per_chunk = int(
+ (image_height // processor.patch_size) * (image_width // processor.patch_size) // processor.downsample_ratio
+ )
+ aspect_ratios = mm_inputs.pop("aspect_ratios")
+ tokens_for_this_image = processor._prompt_split_image(aspect_ratios[0], num_patches_per_chunk)
+ check_inputs["expected_mm_messages"] = [
+ {key: value.replace("", tokens_for_this_image) for key, value in message.items()}
+ for message in MM_MESSAGES
+ ]
+ check_inputs["expected_mm_inputs"] = mm_inputs
+ _check_plugin(**check_inputs)
+
+
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index 0f8bc6cc..9ec25392 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -29,7 +29,8 @@ if TYPE_CHECKING:
HF_TOKEN = os.getenv("HF_TOKEN")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
MESSAGES = [
{"role": "user", "content": "How are you"},
@@ -75,7 +76,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
- tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
prompt_str = (
@@ -90,7 +91,7 @@ def test_encode_oneturn(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
- tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
prompt_str_1 = (
@@ -111,8 +112,8 @@ def test_encode_multiturn(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
- tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
- ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
+ ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
tokenizer.chat_template = template._get_jinja_template(tokenizer) # llama3 template no replace
assert tokenizer.chat_template != ref_tokenizer.chat_template
@@ -120,7 +121,7 @@ def test_jinja_template(use_fast: bool):
def test_ollama_modelfile():
- tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
assert template.get_ollama_modelfile(tokenizer) == (
"# ollama modelfile auto-generated by llamafactory\n\n"
@@ -137,7 +138,7 @@ def test_ollama_modelfile():
def test_get_stop_token_ids():
- tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
@@ -152,7 +153,7 @@ def test_gemma_template(use_fast: bool):
"model\n"
)
answer_str = "很高兴认识你!\n"
- _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, use_fast)
+ _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@@ -168,7 +169,20 @@ def test_llama3_template(use_fast: bool):
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
-@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+@pytest.mark.parametrize(
+ "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
+)
+def test_llama4_template(use_fast: bool):
+ prompt_str = (
+ "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
+ "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
+ "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+ "<|header_start|>assistant<|header_end|>\n\n"
+ )
+ answer_str = "很高兴认识你!<|eot|>"
+ _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
+
+
@pytest.mark.parametrize(
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
)
@@ -183,35 +197,21 @@ def test_phi4_template(use_fast: bool):
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
-@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") # TODO: why it is gated?
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen_template(use_fast: bool):
prompt_str = (
- "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "很高兴认识你!<|im_end|>\n"
- _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
+ _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
-@pytest.mark.parametrize("use_fast", [True, False])
-@pytest.mark.xfail(reason="Yi tokenizer is broken.")
-def test_yi_template(use_fast: bool):
- prompt_str = (
- "<|im_start|>user\nHow are you<|im_end|>\n"
- "<|im_start|>assistant\nI am fine!<|im_end|>\n"
- "<|im_start|>user\n你好<|im_end|>\n"
- "<|im_start|>assistant\n"
- )
- answer_str = "很高兴认识你!<|im_end|>\n"
- _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str, use_fast)
-
-
-def test_parse_template():
- tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, token=HF_TOKEN)
+def test_parse_llama3_template():
+ tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
template = parse_template(tokenizer)
assert template.format_user.slots == [
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
@@ -223,12 +223,11 @@ def test_parse_template():
assert template.default_system == ""
-@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_parse_qwen_template():
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct", token=HF_TOKEN)
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
template = parse_template(tokenizer)
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
assert template.format_prefix.slots == []
- assert template.default_system == "You are a helpful assistant."
+ assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
diff --git a/tests/e2e/test_chat.py b/tests/e2e/test_chat.py
index 98818f27..3221b6de 100644
--- a/tests/e2e/test_chat.py
+++ b/tests/e2e/test_chat.py
@@ -17,10 +17,10 @@ import os
from llamafactory.chat import ChatModel
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
INFER_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
diff --git a/tests/e2e/test_train.py b/tests/e2e/test_train.py
index f16b3522..48f6f79d 100644
--- a/tests/e2e/test_train.py
+++ b/tests/e2e/test_train.py
@@ -21,12 +21,12 @@ from llamafactory.train.tuner import export_model, run_exp
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"do_train": True,
"finetuning_type": "lora",
"dataset_dir": "REMOTE:" + DEMO_DATA,
@@ -35,10 +35,11 @@ TRAIN_ARGS = {
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
+ "report_to": "none",
}
INFER_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"adapter_name_or_path": TINY_LLAMA_ADAPTER,
"finetuning_type": "lora",
"template": "llama3",
diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py
index a3deda29..0063630a 100644
--- a/tests/model/model_utils/test_attention.py
+++ b/tests/model/model_utils/test_attention.py
@@ -21,10 +21,10 @@ from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import load_infer_model
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
INFER_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"template": "llama3",
}
diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py
index 36c20434..2402e6fb 100644
--- a/tests/model/model_utils/test_checkpointing.py
+++ b/tests/model/model_utils/test_checkpointing.py
@@ -21,10 +21,10 @@ from llamafactory.extras.misc import get_current_device
from llamafactory.train.test_utils import load_train_model
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
diff --git a/tests/model/test_base.py b/tests/model/test_base.py
index 9e8c5048..e06b4467 100644
--- a/tests/model/test_base.py
+++ b/tests/model/test_base.py
@@ -19,12 +19,12 @@ import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
INFER_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"template": "llama3",
"infer_dtype": "float16",
}
@@ -37,7 +37,7 @@ def fix_valuehead_cpu_loading():
def test_base():
model = load_infer_model(**INFER_ARGS)
- ref_model = load_reference_model(TINY_LLAMA)
+ ref_model = load_reference_model(TINY_LLAMA3)
compare_model(model, ref_model)
diff --git a/tests/model/test_freeze.py b/tests/model/test_freeze.py
index 97200852..b82ec88d 100644
--- a/tests/model/test_freeze.py
+++ b/tests/model/test_freeze.py
@@ -19,10 +19,10 @@ import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "freeze",
@@ -36,7 +36,7 @@ TRAIN_ARGS = {
}
INFER_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"finetuning_type": "freeze",
"template": "llama3",
"infer_dtype": "float16",
diff --git a/tests/model/test_full.py b/tests/model/test_full.py
index 8aff2223..9058b6ac 100644
--- a/tests/model/test_full.py
+++ b/tests/model/test_full.py
@@ -19,10 +19,10 @@ import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
@@ -36,7 +36,7 @@ TRAIN_ARGS = {
}
INFER_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"finetuning_type": "full",
"template": "llama3",
"infer_dtype": "float16",
diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py
index 1cda7bb7..3d394c33 100644
--- a/tests/model/test_lora.py
+++ b/tests/model/test_lora.py
@@ -27,14 +27,14 @@ from llamafactory.train.test_utils import (
)
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
@@ -48,7 +48,7 @@ TRAIN_ARGS = {
}
INFER_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"adapter_name_or_path": TINY_LLAMA_ADAPTER,
"finetuning_type": "lora",
"template": "llama3",
@@ -81,13 +81,13 @@ def test_lora_train_extra_modules():
def test_lora_train_old_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
- ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
+ ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(model, ref_model)
def test_lora_train_new_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
- ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
+ ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(
model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
)
@@ -105,5 +105,5 @@ def test_lora_train_valuehead():
def test_lora_inference():
model = load_infer_model(**INFER_ARGS)
- ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
+ ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
compare_model(model, ref_model)
diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py
index 08863a07..3b6101f8 100644
--- a/tests/model/test_pissa.py
+++ b/tests/model/test_pissa.py
@@ -19,12 +19,12 @@ import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_PISSA = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
diff --git a/tests/train/test_sft_trainer.py b/tests/train/test_sft_trainer.py
index c520bb3a..66a33af8 100644
--- a/tests/train/test_sft_trainer.py
+++ b/tests/train/test_sft_trainer.py
@@ -27,10 +27,10 @@ from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
-TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
- "model_name_or_path": TINY_LLAMA,
+ "model_name_or_path": TINY_LLAMA3,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
@@ -41,6 +41,7 @@ TRAIN_ARGS = {
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
+ "report_to": "none",
}
diff --git a/tests/version.txt b/tests/version.txt
index 399e7891..4219b968 100644
--- a/tests/version.txt
+++ b/tests/version.txt
@@ -1,2 +1,2 @@
# change if test fails
-0.9.3.101
+0.9.3.102