mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 20:58:54 +08:00
Compare commits
5 Commits
16ff5a23cb
...
v0.9.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7af909522a | ||
|
|
e016d2480e | ||
|
|
7d719182c9 | ||
|
|
01398eb18d | ||
|
|
8e68764b65 |
16
README.md
16
README.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
|
||||
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
|
||||
|
||||
\[ English | [中文](README_zh.md) \]
|
||||
|
||||
@@ -52,14 +50,11 @@ Start local training:
|
||||
Start cloud training:
|
||||
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
|
||||
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||
|
||||
Read technical notes:
|
||||
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
|
||||
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
|
||||
- **Official Blog**: https://blog.llamafactory.net/en/
|
||||
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||
|
||||
> [!NOTE]
|
||||
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
|
||||
@@ -78,7 +73,6 @@ Read technical notes:
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Quickstart](#quickstart)
|
||||
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
|
||||
- [LLaMA Factory Online](#llama-factory-online)
|
||||
- [Build Docker](#build-docker)
|
||||
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
|
||||
- [Download from ModelScope Hub](#download-from-modelscope-hub)
|
||||
@@ -117,15 +111,11 @@ Read technical notes:
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
|
||||
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
|
||||
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
|
||||
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
|
||||
|
||||
<details><summary>All Blogs</summary>
|
||||
|
||||
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
|
||||
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
|
||||
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
|
||||
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
|
||||
@@ -661,10 +651,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online
|
||||
|
||||
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
|
||||
|
||||
### Build Docker
|
||||
|
||||
For CUDA users:
|
||||
|
||||
16
README_zh.md
16
README_zh.md
@@ -15,8 +15,6 @@
|
||||
|
||||
[](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
|
||||
[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
|
||||
[](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
|
||||
[](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
|
||||
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
|
||||
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
|
||||
[](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
|
||||
@@ -38,7 +36,7 @@
|
||||
|
||||
</div>
|
||||
|
||||
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)、[NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)、[大模型实验室群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg) 或 [LLaMA Factory Online 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png)。
|
||||
👋 加入我们的[微信群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg)和 [NPU 用户群](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg)。
|
||||
|
||||
\[ [English](README.md) | 中文 \]
|
||||
|
||||
@@ -52,8 +50,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
开始云端训练:
|
||||
- **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
|
||||
- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
|
||||
- **LLaMA Factory Online(在线微调)**:https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
|
||||
- **九章智算云(算力优惠活动)**:https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
|
||||
|
||||
阅读技术文档:
|
||||
- **入门教程**:https://zhuanlan.zhihu.com/p/695287607
|
||||
@@ -61,7 +57,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- **框架文档**:https://llamafactory.readthedocs.io/zh-cn/latest/
|
||||
- **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
|
||||
- **官方博客**:https://blog.llamafactory.net/
|
||||
- **官方课程**:https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
|
||||
|
||||
> [!NOTE]
|
||||
> 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
|
||||
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
- [数据准备](#数据准备)
|
||||
- [快速开始](#快速开始)
|
||||
- [LLaMA Board 可视化微调](#llama-board-可视化微调由-gradio-驱动)
|
||||
- [LLaMA Factory Online 在线微调](#llama-factory-online-在线微调)
|
||||
- [构建 Docker](#构建-docker)
|
||||
- [利用 vLLM 部署 OpenAI API](#利用-vllm-部署-openai-api)
|
||||
- [从魔搭社区下载](#从魔搭社区下载)
|
||||
@@ -119,15 +113,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
|
||||
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: 用2张4090级的GPU+CPU 微调 1000B规模的超大模型](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (中文)
|
||||
- 💡 [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
|
||||
- [使用 LLaMA-Factory 微调心理健康大模型](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory)(中文)
|
||||
- [使用 LLaMA-Factory 构建 GPT-OSS 角色扮演模型](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory)(中文)
|
||||
- [基于 LLaMA-Factory 和 EasyR1 打造一站式无代码大模型强化学习和部署平台 LLM Model Hub](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)(中文)
|
||||
- [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
|
||||
|
||||
<details><summary>全部博客</summary>
|
||||
|
||||
- [使用 LLaMA-Factory 微调 Llama3.1-70B 医学诊断模型](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory)(中文)
|
||||
- [使用 LLaMA-Factory 微调 Qwen2.5-VL 实现自动驾驶场景微调](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory)(中文)
|
||||
- [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
|
||||
- [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
|
||||
- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
|
||||
@@ -662,10 +652,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
|
||||
llamafactory-cli webui
|
||||
```
|
||||
|
||||
### LLaMA Factory Online 在线微调
|
||||
|
||||
详情阅读该[文档](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory)。
|
||||
|
||||
### 构建 Docker
|
||||
|
||||
CUDA 用户:
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 4
|
||||
batching_strategy: dynamic_padding_free
|
||||
flash_attn: flash_attention2
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -20,7 +20,7 @@ train_dataset: data/v1_sft_demo.yaml
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 4
|
||||
batching_strategy: padding_free
|
||||
flash_attn: fa2
|
||||
flash_attn: flash_attention2
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
VERSION = "0.9.5.dev0"
|
||||
VERSION = "0.9.5"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
|
||||
@@ -162,8 +162,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
|
||||
if position_ids is not None and position_ids.ndim == 3:
|
||||
position_ids = position_ids[0]
|
||||
|
||||
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
|
||||
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
|
||||
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
|
||||
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
|
||||
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
|
||||
if position_ids is not None and batch_size == 1:
|
||||
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
|
||||
else:
|
||||
cu_seqlens = None
|
||||
|
||||
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
|
||||
# standard causal-conv1d path that the upstream forward uses.
|
||||
|
||||
@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
model_args, data_args, training_args, sample_args = parsed_args
|
||||
# Seed as early as possible after argument parsing so all downstream
|
||||
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||
for arg in parsed_args:
|
||||
seed = getattr(arg, "seed", None)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
break
|
||||
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
|
||||
|
||||
return tuple(parsed_args)
|
||||
return model_args, data_args, training_args, sample_args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -85,6 +85,10 @@ class TrainingArguments:
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
full_determinism: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
|
||||
)
|
||||
resume_from_checkpoint: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
|
||||
|
||||
@@ -228,6 +228,30 @@ class NpuMoeFused:
|
||||
routed_out = self.experts(hidden_states, routing_weights, router_indices)
|
||||
return routed_out
|
||||
|
||||
@staticmethod
|
||||
def npu_moe_experts_v5_forward(
|
||||
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass for Transformers v5+ MoE experts using NPU fused operations.
|
||||
|
||||
Transformers v5 stores expert weights in F.linear layout:
|
||||
gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
|
||||
down_proj: [num_experts, hidden_dim, intermediate_dim]
|
||||
The NPU grouped matmul path expects matmul layout, so both weights are transposed.
|
||||
"""
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
|
||||
hidden_states, top_k_index.to(torch.int32)
|
||||
)
|
||||
tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
|
||||
|
||||
gate_up_proj = self.gate_up_proj.transpose(1, 2)
|
||||
down_proj = self.down_proj.transpose(1, 2)
|
||||
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
|
||||
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
|
||||
output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
|
||||
return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
|
||||
|
||||
|
||||
class Qwen3NpuMoeFused:
|
||||
"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
|
||||
|
||||
|
||||
# moe patch config mapping
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
|
||||
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
|
||||
if is_transformers_version_greater_than("5.0.0"):
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3MoeForCausalLM": {
|
||||
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3_5MoeForCausalLM": {
|
||||
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
"Qwen3_5MoeForConditionalGeneration": {
|
||||
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
|
||||
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
|
||||
else:
|
||||
kernel_moe_mapping = {
|
||||
"Qwen3MoeForCausalLM": {
|
||||
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
|
||||
},
|
||||
"Qwen3VLMoeForConditionalGeneration": {
|
||||
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
|
||||
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
|
||||
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
|
||||
|
||||
.. note::
|
||||
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
|
||||
without hardcoding version numbers. Two methods are used: weight value inspection
|
||||
(most reliable) and class name pattern matching (backward compatibility).
|
||||
This must follow the module's forward semantics. Do not infer it from trained
|
||||
weight values because standard RMSNorm weights can also be close to zero.
|
||||
"""
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
weight_mean = module.weight.data.mean().item()
|
||||
if abs(weight_mean) < 0.3:
|
||||
return True
|
||||
residual_rmsnorm_classes = {
|
||||
"Qwen3_5RMSNorm",
|
||||
"Qwen3_5MoeRMSNorm",
|
||||
"Qwen3NextRMSNorm",
|
||||
}
|
||||
|
||||
class_name = module.__class__.__name__
|
||||
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
|
||||
for pattern in residual_patterns:
|
||||
if pattern in class_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
return class_name in residual_rmsnorm_classes
|
||||
|
||||
|
||||
def npu_rms_norm_forward(self, hidden_states):
|
||||
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
|
||||
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
|
||||
|
||||
if hasattr(self, "weight") and self.weight is not None:
|
||||
if _should_use_residual_rmsnorm(self):
|
||||
if getattr(self, "_npu_use_residual_rmsnorm", False):
|
||||
effective_weight = 1.0 + self.weight.float()
|
||||
else:
|
||||
effective_weight = self.weight.float()
|
||||
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
if "Gated" in module.__class__.__name__:
|
||||
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
|
||||
else:
|
||||
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
|
||||
module.forward = types.MethodType(npu_rms_norm_forward, module)
|
||||
|
||||
return model
|
||||
|
||||
@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
|
||||
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
|
||||
# in shape : e.g., [s/p:h:]
|
||||
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
|
||||
|
||||
# scatter 2, gather 1
|
||||
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
|
||||
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
|
||||
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** -0.5
|
||||
|
||||
if attention_mask is None:
|
||||
if position_ids is not None:
|
||||
attention_mask = torch.ones_like(position_ids).to(torch.int64)
|
||||
else:
|
||||
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
|
||||
if position_ids is not None:
|
||||
global_position_ids = [
|
||||
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_position_ids, position_ids, group=self.spg)
|
||||
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = attention_mask.to(torch.int64)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(torch.int64)
|
||||
|
||||
global_attention_mask = [
|
||||
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
|
||||
attention_mask = torch.cat(global_attention_mask, dim=1)
|
||||
global_attention_mask = [
|
||||
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
|
||||
attention_mask = torch.cat(global_attention_mask, dim=1)
|
||||
|
||||
context_layer = self.attn_fn(
|
||||
q,
|
||||
|
||||
@@ -84,6 +84,37 @@ def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchI
|
||||
return sizes
|
||||
|
||||
|
||||
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
|
||||
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
sizes = []
|
||||
index = 0
|
||||
|
||||
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
|
||||
current_tokens = 0
|
||||
used = 0
|
||||
is_complete = False
|
||||
|
||||
while index + used < len(samples):
|
||||
sample = samples[index + used]
|
||||
sample_len = min(len(sample["input_ids"]), cutoff_len)
|
||||
|
||||
if current_tokens + sample_len > budget:
|
||||
is_complete = True
|
||||
break
|
||||
|
||||
current_tokens += sample_len
|
||||
used += 1
|
||||
|
||||
if used <= 0 or not is_complete:
|
||||
break
|
||||
|
||||
sizes.append(used)
|
||||
index += used
|
||||
|
||||
return sizes
|
||||
|
||||
|
||||
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
|
||||
"""Pack fixed samples into one padding-free sequence without a token budget."""
|
||||
packed: dict[str, list[Any]] = {}
|
||||
@@ -206,3 +237,47 @@ def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInf
|
||||
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
|
||||
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("compute_length")
|
||||
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
|
||||
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
|
||||
return ceil(len(data_provider) / batch_size)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
|
||||
def fill_dynamic_padding_free_buffer(
|
||||
buffer: StatefulBuffer,
|
||||
batch_info: BatchInfo,
|
||||
next_samples: Callable[[bool], list[ModelInput] | None],
|
||||
) -> None:
|
||||
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
|
||||
samples = next_samples(True)
|
||||
if samples is None:
|
||||
break
|
||||
buffer.put(samples)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
|
||||
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
|
||||
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
|
||||
return None
|
||||
|
||||
batch = []
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
|
||||
for num_samples in micro_batch_sample_counts:
|
||||
samples = buffer.get(num_samples)
|
||||
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
|
||||
if packed_batch is None:
|
||||
return None
|
||||
|
||||
batch.append(packed_batch)
|
||||
|
||||
return batch
|
||||
|
||||
@@ -13,22 +13,46 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import set_seed as hf_set_seed
|
||||
|
||||
from ..accelerator.helper import is_torch_npu_available
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
def enable_full_determinism(seed: int) -> None:
|
||||
"""Enable full deterministic mode for reproducible distributed training."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.enabled = False
|
||||
if is_torch_npu_available():
|
||||
torch.npu.manual_seed(seed)
|
||||
torch.npu.manual_seed_all(seed)
|
||||
|
||||
|
||||
def set_seed(seed: int, full_determinism: bool = False) -> None:
|
||||
"""Set seed for reproducibility.
|
||||
|
||||
Args:
|
||||
seed: Random seed.
|
||||
full_determinism: Whether to enable full deterministic mode.
|
||||
"""
|
||||
hf_set_seed(seed)
|
||||
if full_determinism:
|
||||
enable_full_determinism(seed)
|
||||
else:
|
||||
hf_set_seed(seed)
|
||||
|
||||
|
||||
def is_tokenizer(processor: Processor) -> bool:
|
||||
|
||||
@@ -16,7 +16,11 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.core.utils.batching import BatchGenerator
|
||||
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
|
||||
from llamafactory.v1.plugins.trainer_plugins.batching import (
|
||||
BatchingPlugin,
|
||||
_get_dynamic_micro_batch_sizes,
|
||||
_get_dynamic_padding_free_micro_batch_sizes,
|
||||
)
|
||||
from llamafactory.v1.utils.constants import IGNORE_INDEX
|
||||
from llamafactory.v1.utils.objects import StatefulBuffer
|
||||
|
||||
@@ -74,6 +78,7 @@ def test_batching_plugin_data_provider_batch_sizes():
|
||||
|
||||
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
|
||||
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
|
||||
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching():
|
||||
@@ -196,3 +201,168 @@ def test_normal_batching():
|
||||
batch = next(iter(batch_generator))
|
||||
assert len(batch) == 2
|
||||
assert batch[0]["input_ids"].shape == (4, 10)
|
||||
|
||||
|
||||
def test_dynamic_padding_free():
|
||||
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
|
||||
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
|
||||
# input_ids breakdown:
|
||||
# sample 0: [0,1,2] (length=3)
|
||||
# sample 1: [10,11,12,13] (length=4)
|
||||
# sample 2: [20,21,22,23,24,25] (length=6)
|
||||
# sample 3: [30,31] (length=2)
|
||||
# sample 4: [40-47] (length=8)
|
||||
# sample 5: [50-58] (length=9)
|
||||
samples = [
|
||||
_make_model_input(3, 0),
|
||||
_make_model_input(4, 10),
|
||||
_make_model_input(6, 20),
|
||||
_make_model_input(2, 30),
|
||||
_make_model_input(8, 40),
|
||||
_make_model_input(9, 50),
|
||||
]
|
||||
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
|
||||
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put(samples)
|
||||
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
|
||||
|
||||
assert batch is not None
|
||||
assert len(batch) == 1 # num_micro_batch=1
|
||||
packed_batch = batch[0]
|
||||
|
||||
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
|
||||
assert packed_batch["input_ids"].shape == (1, 15)
|
||||
|
||||
# Verify input_ids concatenation (first label of non-initial samples set to IGNORE_INDEX)
|
||||
assert packed_batch["input_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
30,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
|
||||
assert packed_batch["labels"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
IGNORE_INDEX,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
IGNORE_INDEX,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
IGNORE_INDEX,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify attention_mask
|
||||
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
|
||||
|
||||
# Verify position_ids
|
||||
assert packed_batch["position_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3, # Sample 1
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5, # Sample 2
|
||||
0,
|
||||
1,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
|
||||
assert len(buffer) == 2
|
||||
|
||||
|
||||
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put([_make_model_input(6, 0)])
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
|
||||
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
|
||||
# Batch generation does not read from the data iterator. It only returns None and keeps
|
||||
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
|
||||
assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
|
||||
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
|
||||
|
||||
Data provider yields one sample of length 6 per iteration.
|
||||
_fill_buffer keeps restarting iterator until next sample exceeds budget.
|
||||
Budget = 2 * 10 = 20 tokens.
|
||||
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
|
||||
So buffer will have 4 samples after fill_buffer.
|
||||
"""
|
||||
samples = [_make_model_input(6, 0)]
|
||||
data_provider = _RestartableDataProvider([[sample] for sample in samples])
|
||||
|
||||
batch_generator = BatchGenerator.__new__(BatchGenerator)
|
||||
batch_generator.batching_strategy = "dynamic_padding_free"
|
||||
batch_generator.micro_batch_size = 2
|
||||
batch_generator.num_micro_batch = 1
|
||||
batch_generator._buffer = StatefulBuffer()
|
||||
batch_generator._data_provider = data_provider
|
||||
batch_generator._data_iter = iter(data_provider)
|
||||
batch_generator._batch_info = {
|
||||
"micro_batch_size": 2,
|
||||
"num_micro_batch": 1,
|
||||
"cutoff_len": 10,
|
||||
}
|
||||
|
||||
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
|
||||
batch_generator._fill_buffer()
|
||||
|
||||
# Buffer after restarts:
|
||||
# 3 samples can fit (18 tokens)
|
||||
# 4th sample is kept in buffer for next batch
|
||||
# => num_iters = 4
|
||||
assert data_provider.num_iters == 4
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(
|
||||
batch_generator._buffer.samples, batch_generator._batch_info
|
||||
) == [3]
|
||||
|
||||
batch = batch_generator._generate_batch()
|
||||
|
||||
# Output batch:
|
||||
# dynamic_padding_free returns [micro_batch_0]
|
||||
# 3 samples packed into shape [1, 18]
|
||||
assert batch is not None
|
||||
assert len(batch) == 1
|
||||
assert batch[0]["input_ids"].shape == (1, 18)
|
||||
assert len(batch_generator._buffer) == 1
|
||||
|
||||
Reference in New Issue
Block a user