5 Commits

Author | SHA1 | Message | Date
Yaowei Zheng | 7af909522a | [version] release v0.9.5 (#10532) | 2026-05-30 23:57:09 +08:00
xvxuopop | e016d2480e | [fix] Fix NPU FusedMoE and RMSNorm (#10512) | 2026-05-30 21:42:54 +08:00
jiaqiw09 | 7d719182c9 | [model] fix non-packing batch (bsz>1) for Qwen3.5 with flash attention (#10529) | 2026-05-30 21:41:41 +08:00
jiaqiw09 | 01398eb18d | [v1] fix padding free with sp (#10513) | 2026-05-26 23:49:21 +08:00
cxy | 8e68764b65 | [v1] Implement dynamic padding-free strategy for batching (#10507); Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn> | 2026-05-25 20:40:21 +08:00
14 changed files with 394 additions and 78 deletions


@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [Lab4AI](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), [LLaMA Factory Online](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.jpg) user group.
👋 Join our [WeChat](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) and [NPU](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg) user groups.
\[ English | [中文](README_zh.md) \]
@@ -52,14 +50,11 @@ Start local training:
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
Read technical notes:
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/en/latest/
- **Documentation (AMD GPU)**: https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/fine_tune/llama_factory_llama3.html
- **Official Blog**: https://blog.llamafactory.net/en/
- **Official Course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please use them with caution.
@@ -78,7 +73,6 @@ Read technical notes:
- [Data Preparation](#data-preparation)
- [Quickstart](#quickstart)
- [Fine-Tuning with LLaMA Board GUI](#fine-tuning-with-llama-board-gui-powered-by-gradio)
- [LLaMA Factory Online](#llama-factory-online)
- [Build Docker](#build-docker)
- [Deploy with OpenAI-style API and vLLM](#deploy-with-openai-style-api-and-vllm)
- [Download from ModelScope Hub](#download-from-modelscope-hub)
@@ -117,15 +111,11 @@ Read technical notes:
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000 Billion models with 2 4090-GPU + CPU](https://blog.llamafactory.net/en/posts/ktransformers/) (English)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune GPT-OSS for Role-Playing using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- [A One-Stop Code-Free Model Reinforcement Learning and Deployment Platform based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
<details><summary>All Blogs</summary>
- [Fine-tune Llama3.1-70B for Medical Diagnosis using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for Autonomous Driving using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [A One-Stop Code-Free Model Fine-Tuning \& Deployment Platform based on SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory Multi-Modal Fine-Tuning Practice: Fine-Tuning Qwen2-VL for Personal Tourist Guide](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
@@ -661,10 +651,6 @@ See [examples/README.md](examples/README.md) for advanced usage (including distr
llamafactory-cli webui
```
### LLaMA Factory Online
Read our [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
### Build Docker
For CUDA users:


@@ -15,8 +15,6 @@
[![Open in Colab](assets/thirdparty/colab.svg)](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)
[![Open in DSW](assets/thirdparty/dsw.svg)](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[![Open in Lab4ai](assets/thirdparty/lab4ai.svg)](https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory)
[![Open in Online](assets/thirdparty/online.svg)](https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory)
[![Open in Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[![Open in Studios](https://img.shields.io/badge/ModelScope-Open%20in%20Studios-blue)](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[![Open in Novita](https://img.shields.io/badge/Novita-Deploy%20Template-blue)](https://novita.ai/templates-library/105981?sharer=88115474-394e-4bda-968e-b88e123d0c47)
@@ -38,7 +36,7 @@
</div>
👋 Join our [WeChat group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg), [NPU user group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg), [LLM Lab group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/lab4ai.jpg), or [LLaMA Factory Online user group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/online.png).
👋 Join our [WeChat group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/main.jpg) or [NPU user group](https://github.com/hiyouga/llamafactory-community/blob/main/wechat/npu.jpg).
\[ [English](README.md) | 中文 \]
@@ -52,8 +50,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
Start cloud training:
- **Colab (free)**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **LLaMA Factory Online (online fine-tuning)**: https://www.llamafactory.com.cn/?utm_source=LLaMA-Factory
- **Alaya NeW (cloud GPU deal)**: https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory
Read technical notes:
- **Getting-started tutorial**: https://zhuanlan.zhihu.com/p/695287607
@@ -61,7 +57,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- **Framework documentation**: https://llamafactory.readthedocs.io/zh-cn/latest/
- **Framework documentation (Ascend NPU)**: https://ascend.github.io/docs/sources/llamafactory/
- **Official blog**: https://blog.llamafactory.net/
- **Official course**: https://www.lab4ai.cn/course/detail?id=7c13e60f6137474eb40f6fd3983c0f46&utm_source=LLaMA-Factory
> [!NOTE]
> All websites other than the links above are unauthorized third-party websites. Please use them with caution.
@@ -80,7 +75,6 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- [Data Preparation](#数据准备)
- [Quickstart](#快速开始)
- [Fine-Tuning with LLaMA Board GUI](#llama-board-可视化微调由-gradio-驱动)
- [LLaMA Factory Online fine-tuning](#llama-factory-online-在线微调)
- [Build Docker](#构建-docker)
- [Deploy with OpenAI-style API and vLLM](#利用-vllm-部署-openai-api)
- [Download from ModelScope Hub](#从魔搭社区下载)
@@ -119,15 +113,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
- 💡 [KTransformers Fine-Tuning × LLaMA Factory: Fine-tuning 1000B-scale models with two 4090-class GPUs + CPU](https://swcil84qspu.feishu.cn/wiki/Z1sSwb2poijybxkyPEkcDG6enVc) (Chinese)
- 💡 [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9) (Chinese)
- [Fine-tune a mental health LLM using LLaMA-Factory](https://www.lab4ai.cn/project/detail?id=25cce32ec131497b9e06a93336a0817f&type=project&utm_source=LLaMA-Factory) (Chinese)
- [Build a GPT-OSS role-playing model using LLaMA-Factory](https://docs.llamafactory.com.cn/docs/documents/best-practice/gptroleplay/?utm_source=LLaMA-Factory) (Chinese)
- [A One-Stop Code-Free LLM Reinforcement Learning and Deployment Platform, LLM Model Hub, based on LLaMA-Factory and EasyR1](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/) (Chinese)
- [Enhancing visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
<details><summary>All Blogs</summary>
- [Fine-tune a Llama3.1-70B medical diagnosis model using LLaMA-Factory](https://docs.alayanew.com/docs/documents/bestPractice/bigModel/llama70B/?utm_source=LLaMA-Factory) (Chinese)
- [Fine-tune Qwen2.5-VL for autonomous driving scenarios using LLaMA-Factory](https://docs.alayanew.com/docs/documents/useGuide/LLaMAFactory/mutiple/?utm_source=LLaMA-Factory) (Chinese)
- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B model to build a news headline classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
- [A One-Stop Code-Free Model Fine-Tuning and Deployment Platform, Model Hub, based on Amazon SageMaker and LLaMA-Factory](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
- [LLaMA Factory multimodal fine-tuning practice: fine-tuning Qwen2-VL to build a culture-and-tourism LLM](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
@@ -662,10 +652,6 @@ llamafactory-cli export examples/merge_lora/qwen3_lora_sft.yaml
llamafactory-cli webui
```
### LLaMA Factory Online fine-tuning
For details, read this [documentation](https://docs.llamafactory.com.cn/docs/documents/quickstart/getstarted/?utm_source=LLaMA-Factory).
### Build Docker
For CUDA users:


@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
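
Reviewer note: the batching code later in this diff derives a per-micro-batch token budget as `cutoff_len * micro_batch_size` when `batching_strategy: dynamic_padding_free` is selected. A minimal sketch of that arithmetic for the values in this config, assuming the values are copied by hand into a plain dict (the helper name `token_budget` is hypothetical and not part of the project):

```python
# Values copied by hand from the example config above; the dict is illustrative only.
config = {"micro_batch_size": 4, "cutoff_len": 2048}


def token_budget(cfg: dict) -> int:
    """Hypothetical helper: per-micro-batch token budget (cutoff_len * micro_batch_size)."""
    return cfg["cutoff_len"] * cfg["micro_batch_size"]


print(token_budget(config))  # 8192
```

Under this reading, each packed micro-batch in the example run holds at most 8192 tokens, regardless of how many samples that takes.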


@@ -20,7 +20,7 @@ train_dataset: data/v1_sft_demo.yaml
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: fa2
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10


@@ -19,7 +19,7 @@
from collections import OrderedDict
VERSION = "0.9.5.dev0"
VERSION = "0.9.5"
def print_env() -> None:


@@ -162,8 +162,14 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
if position_ids is not None and position_ids.ndim == 3:
position_ids = position_ids[0]
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
# cu_seqlens for the FLA varlen path is only needed when batch_size == 1:
# packing / neat-packing: always folded into a single sequence (bsz == 1) -> varlen
# non-packing, bsz == 1: single segment, equivalent to a standard single sequence
# non-packing, bsz > 1: not packed, use cu_seqlens=None and standard batched kernels
if position_ids is not None and batch_size == 1:
cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
else:
cu_seqlens = None
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
# standard causal-conv1d path that the upstream forward uses.
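
Reviewer note: the branch above only builds `cu_seqlens` when the batch holds a single (possibly packed) row. The sketch below illustrates the idea in pure Python, assuming each packed segment restarts its position ids at 0. The helpers `cu_seqlens_from_position_ids` and `select_cu_seqlens` are hypothetical stand-ins and do not reproduce `prepare_fa_kwargs_from_position_ids`.

```python
def cu_seqlens_from_position_ids(position_ids: list[int]) -> list[int]:
    """Cumulative segment boundaries for one packed row.

    Assumption: a new segment starts wherever the position id resets to 0.
    """
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    return starts + [len(position_ids)]


def select_cu_seqlens(position_ids, batch_size: int):
    """Mirror the guard in the hunk: varlen boundaries only when batch_size == 1."""
    if position_ids is not None and batch_size == 1:
        return cu_seqlens_from_position_ids(position_ids[0])
    return None  # bsz > 1 (not packed) or no position ids: use standard batched kernels


packed_row = [[0, 1, 2, 0, 1, 2, 3, 0, 1]]
print(select_cu_seqlens(packed_row, batch_size=1))  # [0, 3, 7, 9]
print(select_cu_seqlens([[0, 1, 2], [0, 1, 2]], batch_size=2))  # None
```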


@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
model_args, data_args, training_args, sample_args = parsed_args
# Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
for arg in parsed_args:
seed = getattr(arg, "seed", None)
if seed is not None:
set_seed(seed)
break
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
return tuple(parsed_args)
return model_args, data_args, training_args, sample_args
if __name__ == "__main__":


@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
full_determinism: bool = field(
default=False,
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
)
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
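
Reviewer note: the new `full_determinism` field follows the same dataclass `field(default=..., metadata={"help": ...})` pattern as `seed`. A minimal, self-contained sketch of that pattern; the class name `TrainingArgumentsSketch` is hypothetical and contains only the two fields visible in this hunk:

```python
from dataclasses import dataclass, field, fields


@dataclass
class TrainingArgumentsSketch:
    """Illustrative subset; not the project's real TrainingArguments."""

    seed: int = field(
        default=42,
        metadata={"help": "Random seed that will be set at the beginning of training."},
    )
    full_determinism: bool = field(
        default=False,
        metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
    )


args = TrainingArgumentsSketch(full_determinism=True)
# The help strings stay attached to each field and can be read back via `fields`.
help_text = {f.name: f.metadata["help"] for f in fields(args)}
print(args.seed, args.full_determinism)  # 42 True
```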


@@ -228,6 +228,30 @@ class NpuMoeFused:
routed_out = self.experts(hidden_states, routing_weights, router_indices)
return routed_out
@staticmethod
def npu_moe_experts_v5_forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""Forward pass for Transformers v5+ MoE experts using NPU fused operations.
Transformers v5 stores expert weights in F.linear layout:
gate_up_proj: [num_experts, 2 * intermediate_dim, hidden_dim]
down_proj: [num_experts, hidden_dim, intermediate_dim]
The NPU grouped matmul path expects matmul layout, so both weights are transposed.
"""
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
hidden_states, top_k_index.to(torch.int32)
)
tokens_per_expert = torch.histc(top_k_index.float(), bins=self.num_experts, min=0, max=self.num_experts).long()
gate_up_proj = self.gate_up_proj.transpose(1, 2)
down_proj = self.down_proj.transpose(1, 2)
intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, gate_up_proj, tokens_per_expert)
intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
output = GmmFunction.apply(intermediate_activations, down_proj, tokens_per_expert)
return torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=top_k_weights)
class Qwen3NpuMoeFused:
"""Container for Qwen3 NPU fused MoE forward functions."""
@@ -283,16 +307,30 @@ class Qwen3NpuMoeFused:
# moe patch config mapping
kernel_moe_mapping = {
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
if is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForCausalLM": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
"Qwen3_5MoeForConditionalGeneration": {
"Qwen3_5MoeExperts": NpuMoeFused.npu_moe_experts_v5_forward,
},
}
}
if not is_transformers_version_greater_than("5.0.0"):
kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
else:
kernel_moe_mapping = {
"Qwen3MoeForCausalLM": {
"Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward,
},
"Qwen3VLMoeForConditionalGeneration": {
"Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
"Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
},
}
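
Reviewer note: two steps in `npu_moe_experts_v5_forward` are easy to check in isolation: counting tokens per expert from `top_k_index`, and transposing the last two weight dimensions from the `F.linear` layout to the matmul layout. The sketch below reproduces both on nested Python lists. It assumes nothing about `torch_npu` or `GmmFunction`, and both helper names are hypothetical.

```python
def tokens_per_expert(top_k_index: list[list[int]], num_experts: int) -> list[int]:
    """Count how many (token, slot) pairs are routed to each expert."""
    counts = [0] * num_experts
    for row in top_k_index:
        for expert_id in row:
            counts[expert_id] += 1
    return counts


def transpose_last_two(weight: list[list[list[float]]]) -> list[list[list[float]]]:
    """[num_experts][out_dim][in_dim] -> [num_experts][in_dim][out_dim]."""
    return [[list(col) for col in zip(*expert)] for expert in weight]


# Three tokens, top-2 routing over four experts.
print(tokens_per_expert([[0, 2], [1, 2], [2, 3]], num_experts=4))  # [1, 1, 3, 1]

# One expert with a 2 x 3 weight in F.linear layout becomes 3 x 2 in matmul layout.
print(transpose_last_two([[[1, 2, 3], [4, 5, 6]]]))  # [[[1, 4], [2, 5], [3, 6]]]
```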


@@ -51,22 +51,17 @@ def _should_use_residual_rmsnorm(module):
bool: ``True`` if the module uses residual parameterization, ``False`` otherwise.
.. note::
This detection ensures compatibility with future model versions (e.g., Qwen3.6, Qwen4.0)
without hardcoding version numbers. Two methods are used: weight value inspection
(most reliable) and class name pattern matching (backward compatibility).
This must follow the module's forward semantics. Do not infer it from trained
weight values because standard RMSNorm weights can also be close to zero.
"""
if hasattr(module, "weight") and module.weight is not None:
weight_mean = module.weight.data.mean().item()
if abs(weight_mean) < 0.3:
return True
residual_rmsnorm_classes = {
"Qwen3_5RMSNorm",
"Qwen3_5MoeRMSNorm",
"Qwen3NextRMSNorm",
}
class_name = module.__class__.__name__
residual_patterns = ["Qwen3_5", "Qwen3_6", "Qwen4"]
for pattern in residual_patterns:
if pattern in class_name:
return True
return False
return class_name in residual_rmsnorm_classes
def npu_rms_norm_forward(self, hidden_states):
@@ -82,7 +77,7 @@ def npu_rms_norm_forward(self, hidden_states):
_eps = getattr(self, "variance_epsilon", None) or getattr(self, "eps", 1e-6)
if hasattr(self, "weight") and self.weight is not None:
if _should_use_residual_rmsnorm(self):
if getattr(self, "_npu_use_residual_rmsnorm", False):
effective_weight = 1.0 + self.weight.float()
else:
effective_weight = self.weight.float()
@@ -162,6 +157,7 @@ class NpuRMSNormKernel(BaseKernel):
if "Gated" in module.__class__.__name__:
module.forward = types.MethodType(npu_gated_rms_norm_forward, module)
else:
module._npu_use_residual_rmsnorm = _should_use_residual_rmsnorm(module)
module.forward = types.MethodType(npu_rms_norm_forward, module)
return model
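
Reviewer note: the change above replaces weight-value inspection with an explicit class-name allowlist. The sketch below shows, on plain Python lists, why the two parameterizations differ and why a weight-mean heuristic can misclassify a standard RMSNorm. It is an illustration under stated assumptions, not the NPU kernel: `rms_norm` and `weight_mean_heuristic` are hypothetical helpers, and the `0.3` threshold is taken from the removed lines of the hunk.

```python
import math

RESIDUAL_RMSNORM_CLASSES = {"Qwen3_5RMSNorm", "Qwen3_5MoeRMSNorm", "Qwen3NextRMSNorm"}


def uses_residual_rmsnorm(class_name: str) -> bool:
    """Allowlist check, as in the updated `_should_use_residual_rmsnorm`."""
    return class_name in RESIDUAL_RMSNORM_CLASSES


def weight_mean_heuristic(weight: list[float]) -> bool:
    """The removed heuristic: treat near-zero mean weights as residual."""
    return abs(sum(weight) / len(weight)) < 0.3


def rms_norm(x: list[float], weight: list[float], residual: bool, eps: float = 1e-6) -> list[float]:
    """RMSNorm with either `weight` or `1 + weight` as the effective scale."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    scale = [(1.0 + w) if residual else w for w in weight]
    return [v / rms * s for v, s in zip(x, scale)]


x = [3.0, 4.0]
zero_weight = [0.0, 0.0]
print(rms_norm(x, zero_weight, residual=False))  # [0.0, 0.0]
print(rms_norm(x, zero_weight, residual=True))   # roughly [0.8485, 1.1314]

# A standard RMSNorm whose trained weights happen to be small is misread by the old heuristic.
print(weight_mean_heuristic([0.1, 0.2]), uses_residual_rmsnorm("Qwen3RMSNorm"))  # True False
```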


@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# in shape : e.g., [s/p:h:]
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
if attention_mask is None:
if position_ids is not None:
attention_mask = torch.ones_like(position_ids).to(torch.int64)
else:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
if position_ids is not None:
global_position_ids = [
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_position_ids, position_ids, group=self.spg)
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
attention_mask = None
else:
attention_mask = attention_mask.to(torch.int64)
if attention_mask is None:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
else:
attention_mask = attention_mask.to(torch.int64)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
context_layer = self.attn_fn(
q,
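
Reviewer note: in the padding-free path above, each sequence-parallel rank holds only its slice of `position_ids`, so the slices are all-gathered and concatenated along the sequence dimension before attention. The sketch below simulates that step with nested lists and no process group. `all_gather_sketch` and `cat_last_dim` are hypothetical helpers, not `torch.distributed` calls.

```python
def all_gather_sketch(shards_by_rank: list) -> list:
    """Simulated all_gather: every rank ends up with the list of all ranks' shards."""
    return [list(shards_by_rank) for _ in shards_by_rank]


def cat_last_dim(shards: list) -> list:
    """Concatenate [batch][seq_len / N] shards along the sequence dimension."""
    batch_size = len(shards[0])
    return [[tok for shard in shards for tok in shard[row]] for row in range(batch_size)]


# Two ranks, batch size 1, one packed row of length 8 split in half.
rank0 = [[0, 1, 2, 0]]
rank1 = [[1, 2, 3, 4]]
gathered_on_rank0 = all_gather_sketch([rank0, rank1])[0]
print(cat_last_dim(gathered_on_rank0))  # [[0, 1, 2, 0, 1, 2, 3, 4]]
```

After the gather, the full row again shows the reset to 0 at index 3, which is the information a varlen attention path needs to locate segment boundaries.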


@@ -84,6 +84,37 @@ def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchI
return sizes
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
current_tokens = 0
used = 0
is_complete = False
while index + used < len(samples):
sample = samples[index + used]
sample_len = min(len(sample["input_ids"]), cutoff_len)
if current_tokens + sample_len > budget:
is_complete = True
break
current_tokens += sample_len
used += 1
if used <= 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
"""Pack fixed samples into one padding-free sequence without a token budget."""
packed: dict[str, list[Any]] = {}
@@ -206,3 +237,47 @@ def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInf
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
return batch
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
if packed_batch is None:
return None
batch.append(packed_batch)
return batch
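
Reviewer note: the greedy token-budget selection in `_get_dynamic_padding_free_micro_batch_sizes` can be traced on plain sample lengths. The sketch below mirrors that loop on a list of integers and adds a simplified packer that only builds `input_ids` and `position_ids`. Both function names are hypothetical, and label masking is left out.

```python
def micro_batch_sizes(lengths: list[int], micro_batch_size: int, num_micro_batch: int, cutoff_len: int) -> list[int]:
    """Greedy split: a micro-batch is complete only once the next sample would exceed the budget."""
    budget = cutoff_len * micro_batch_size
    sizes, index = [], 0
    while index < len(lengths) and len(sizes) < num_micro_batch:
        current_tokens, used, is_complete = 0, 0, False
        while index + used < len(lengths):
            sample_len = min(lengths[index + used], cutoff_len)
            if current_tokens + sample_len > budget:
                is_complete = True
                break
            current_tokens += sample_len
            used += 1
        if used <= 0 or not is_complete:
            break  # not enough buffered samples to prove the micro-batch is full
        sizes.append(used)
        index += used
    return sizes


def pack(samples: list[list[int]]) -> dict:
    """Concatenate samples into one row; position ids restart at 0 for each sample."""
    input_ids, position_ids = [], []
    for sample in samples:
        input_ids.extend(sample)
        position_ids.extend(range(len(sample)))
    return {"input_ids": input_ids, "position_ids": position_ids}


print(micro_batch_sizes([3, 4, 6, 2, 8, 9], 2, 1, 10))  # [4]
print(micro_batch_sizes([6], 2, 1, 10))                 # []
print(pack([[0, 1, 2], [10, 11, 12, 13]])["position_ids"])  # [0, 1, 2, 0, 1, 2, 3]
```

The second call returns an empty list because a single 6-token sample never overflows the 20-token budget, so the loop cannot tell whether more samples would still fit.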


@@ -13,22 +13,46 @@
# limitations under the License.
import random
import numpy as np
import torch
from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed
from ..accelerator.helper import is_torch_npu_available
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None:
def enable_full_determinism(seed: int) -> None:
"""Enable full deterministic mode for reproducible distributed training."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
if is_torch_npu_available():
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
def set_seed(seed: int, full_determinism: bool = False) -> None:
"""Set seed for reproducibility.
Args:
seed: Random seed.
full_determinism: Whether to enable full deterministic mode.
"""
hf_set_seed(seed)
if full_determinism:
enable_full_determinism(seed)
else:
hf_set_seed(seed)
def is_tokenizer(processor: Processor) -> bool:
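
Reviewer note: the property that `set_seed` is meant to provide is that re-seeding reproduces the same random draws. The sketch below shows this with Python's standard `random` module only. `set_seed_sketch` is a hypothetical stand-in and does not touch NumPy, PyTorch, cuDNN, or NPU state the way the functions in this hunk do.

```python
import random


def set_seed_sketch(seed: int) -> None:
    """Stand-in for the real `set_seed`: seeds only Python's `random` module."""
    random.seed(seed)


def draw(n: int = 3) -> list[float]:
    """Draw n pseudo-random floats from the global generator."""
    return [random.random() for _ in range(n)]


set_seed_sketch(42)
first = draw()
set_seed_sketch(42)
second = draw()
print(first == second)  # True
```

Seeding alone does not make GPU kernels deterministic, which is presumably why the diff adds a separate `full_determinism` switch that also calls `torch.use_deterministic_algorithms` and adjusts the cuDNN flags.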


@@ -16,7 +16,11 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
from llamafactory.v1.plugins.trainer_plugins.batching import (
BatchingPlugin,
_get_dynamic_micro_batch_sizes,
_get_dynamic_padding_free_micro_batch_sizes,
)
from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer
@@ -74,6 +78,7 @@ def test_batching_plugin_data_provider_batch_sizes():
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching():
@@ -196,3 +201,168 @@ def test_normal_batching():
batch = next(iter(batch_generator))
assert len(batch) == 2
assert batch[0]["input_ids"].shape == (4, 10)
def test_dynamic_padding_free():
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
# input_ids breakdown:
# sample 0: [0,1,2] (length=3)
# sample 1: [10,11,12,13] (length=4)
# sample 2: [20,21,22,23,24,25] (length=6)
# sample 3: [30,31] (length=2)
# sample 4: [40-47] (length=8)
# sample 5: [50-58] (length=9)
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1 # num_micro_batch=1
packed_batch = batch[0]
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
assert packed_batch["input_ids"].shape == (1, 15)
# Verify input_ids concatenation (input_ids are concatenated unchanged; only labels are masked)
assert packed_batch["input_ids"].tolist() == [
[
0,
1,
2, # Sample 0
10,
11,
12,
13, # Sample 1
20,
21,
22,
23,
24,
25, # Sample 2
30,
31,
] # Sample 3
]
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
assert packed_batch["labels"].tolist() == [
[
0,
1,
2, # Sample 0
IGNORE_INDEX,
11,
12,
13, # Sample 1
IGNORE_INDEX,
21,
22,
23,
24,
25, # Sample 2
IGNORE_INDEX,
31,
] # Sample 3
]
# Verify attention_mask
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
# Verify position_ids
assert packed_batch["position_ids"].tolist() == [
[
0,
1,
2, # Sample 0
0,
1,
2,
3, # Sample 1
0,
1,
2,
3,
4,
5, # Sample 2
0,
1,
] # Sample 3
]
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
assert len(buffer) == 2
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
Data provider yields one sample of length 6 per iteration.
_fill_buffer keeps restarting iterator until next sample exceeds budget.
Budget = 2 * 10 = 20 tokens.
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
So buffer will have 4 samples after fill_buffer.
"""
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_padding_free"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
batch_generator._fill_buffer()
# Buffer after restarts:
# 3 samples can fit (18 tokens)
# 4th sample is kept in buffer for next batch
# => num_iters = 4
assert data_provider.num_iters == 4
assert _get_dynamic_padding_free_micro_batch_sizes(
batch_generator._buffer.samples, batch_generator._batch_info
) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic_padding_free returns [micro_batch_0]
# 3 samples packed into shape [1, 18]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 18)
assert len(batch_generator._buffer) == 1
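
Reviewer note: the fill-buffer test above relies on the buffer being refilled until one more sample would overflow the token budget. The sketch below simulates that loop with `itertools.cycle` standing in for a restartable data provider. It counts individual draws rather than real iterator restarts, and both helper names are hypothetical. It assumes a positive sample length no larger than the budget, otherwise the loop would not terminate.

```python
from itertools import cycle


def is_micro_batch_complete(lengths: list[int], budget: int) -> bool:
    """True once some non-empty prefix fits and the next buffered sample would overflow."""
    total = 0
    for n in lengths:
        if total + n > budget:
            return total > 0
        total += n
    return False


def fill_buffer(sample_len: int, budget: int) -> tuple[list[int], int]:
    """Keep drawing samples until the micro-batch is provably full."""
    buffer, num_draws = [], 0
    provider = cycle([sample_len])  # restartable provider that always yields one sample
    while not is_micro_batch_complete(buffer, budget):
        buffer.append(next(provider))
        num_draws += 1
    return buffer, num_draws


buffer, num_draws = fill_buffer(sample_len=6, budget=20)
print(len(buffer), num_draws)  # 4 4
```

Three 6-token samples (18 tokens) fit in the 20-token budget. The fourth sample is what proves the micro-batch is full, and it stays in the buffer for the next batch, matching the `num_iters == 4` and final buffer-length assertions in the test.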