3 Commits

Author SHA1 Message Date
sunyi0505
b5afabe3d2 [v1] support ulysses cp for fsdp2 (#10262) 2026-03-27 16:22:48 +08:00
jiaqiw09
df2e6edb7e [v1] add init on rank0 for fsdp2 (#10264) 2026-03-27 14:54:03 +08:00
Goalina
d02fcd3588 [ci] add nginx cache config for Ascend NPU CI environment (#10323) 2026-03-27 10:04:16 +08:00
17 changed files with 642 additions and 19 deletions

View File

@@ -49,6 +49,12 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v6 uses: actions/checkout@v6
- name: Set nginx-cache for Ascend CI
run: |
sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
- name: Install uv - name: Install uv
uses: astral-sh/setup-uv@v7 uses: astral-sh/setup-uv@v7
with: with:

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-4B model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B model: Qwen/Qwen3-0.6B
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -0,0 +1,23 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
  cp_mode: ulysses  # Ulysses (all-to-all) context parallelism
  cp_size: 2  # number of ranks the sequence is sharded across
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_ulysses_cp
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
# NOTE(review): cp_size > 1 forces flash_attention_2, which does not run in fp32 —
# confirm whether bf16 should be true here.
bf16: false
max_steps: 10

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-4B model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink
@@ -28,7 +27,6 @@ train_dataset: data/v1_sft_demo.yaml
### training ### training
output_dir: ./outputs/test_lora output_dir: ./outputs/test_lora
micro_batch_size: 1 micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048 cutoff_len: 2048
learning_rate: 1.0e-4 learning_rate: 1.0e-4
bf16: true bf16: true

View File

@@ -0,0 +1,40 @@
model: Qwen/Qwen3-4B
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
  name: lora
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: all
# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto
# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
# Materialize weights on rank 0 only, then scatter shards to the other ranks.
init_config:
  name: init_on_rank0
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -71,6 +71,7 @@ class BaseTrainer:
# cached variables # cached variables
self.device = DistributedInterface().current_device self.device = DistributedInterface().current_device
self.dp_size = DistributedInterface().get_world_size(Dim.DP) self.dp_size = DistributedInterface().get_world_size(Dim.DP)
self.cp_size = DistributedInterface().get_world_size(Dim.CP)
self.model_input_names = self.renderer.processor.model_input_names self.model_input_names = self.renderer.processor.model_input_names
self._create_batch_generator() self._create_batch_generator()
@@ -114,6 +115,21 @@ class BaseTrainer:
# Callbacks: TrainerState tracks progress across the full run. # Callbacks: TrainerState tracks progress across the full run.
self.state = TrainerState(num_training_steps=self.num_training_steps) self.state = TrainerState(num_training_steps=self.num_training_steps)
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
if model.config.model_type == "qwen3_5":
raise RuntimeError(
"Sequence parallel is not supported for qwen3.5 model due to its different attention implementation, which will be supported in the future."
)
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2":
logger.warning_rank0(
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
)
model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None: def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator( self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset, dataset=self.train_dataset,
@@ -172,7 +188,7 @@ class BaseTrainer:
""" """
batch_size, _ = batch["labels"].shape batch_size, _ = batch["labels"].shape
model_inputs = { model_inputs = {
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)
} }
labels = batch["labels"].to(self.device, non_blocking=True) labels = batch["labels"].to(self.device, non_blocking=True)
outputs: ModelOutput = model(**model_inputs) outputs: ModelOutput = model(**model_inputs)
@@ -206,6 +222,13 @@ class BaseTrainer:
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM) step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
num_micro = len(micro_batches) num_micro = len(micro_batches)
for i, micro_batch in enumerate(micro_batches): for i, micro_batch in enumerate(micro_batches):
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
from ..plugins.model_plugins.parallelization.sequence_parallel import (
SequenceParallelLossPlugin,
)
loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch)
else:
loss = self.compute_loss(micro_batch) loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch]) mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# fsdp uses mean reduction so we need to scale the loss by dp_size # fsdp uses mean reduction so we need to scale the loss by dp_size
@@ -223,7 +246,24 @@ class BaseTrainer:
# deepspeed: engine.step() already ran inside backward at the sync boundary # deepspeed: engine.step() already ran inside backward at the sync boundary
grad_norm = self._deepspeed_engine.get_grad_norm() grad_norm = self._deepspeed_engine.get_grad_norm()
else: else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item() if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
parameters = self.model.parameters()
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
parameters = list(parameters)
grads = [p.grad for p in parameters if p.grad is not None]
grad_norm = _get_total_norm(grads)
grad_norm = grad_norm.to(self.device)
_clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
grad_norm = grad_norm.full_tensor().item()
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.args.max_grad_norm
).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float # isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType] if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]

View File

@@ -140,6 +140,9 @@ class ModelEngine:
**init_kwargs, **init_kwargs,
) )
init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default"
model._init_mode = init_mode
if self.args.peft_config is None: if self.args.peft_config is None:
if self.is_train: if self.is_train:
logger.info_rank0("Fine-tuning mode: full tuning") logger.info_rank0("Fine-tuning mode: full tuning")
@@ -147,6 +150,9 @@ class ModelEngine:
else: else:
logger.info_rank0("Inference the original model") logger.info_rank0("Inference the original model")
else: else:
if self.args.peft_config.name == "lora" and init_mode == "init_on_meta":
raise ValueError("Currently lora stage does not support loading model by meta.")
from ..plugins.model_plugins.peft import PeftPlugin from ..plugins.model_plugins.peft import PeftPlugin
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train) model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)

View File

@@ -146,6 +146,8 @@ class Renderer:
for sample in samples: for sample in samples:
if "messages" in sample: if "messages" in sample:
model_input = self.render_messages(sample["messages"], sample.get("tools")) model_input = self.render_messages(sample["messages"], sample.get("tools"))
if "position_ids" not in model_input:
model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
elif "chosen_messages" in sample and "rejected_messages" in sample: elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools")) chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools")) rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))

View File

@@ -0,0 +1,59 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
import torch.distributed as dist
from torch import Tensor
def all_to_all_tensor(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
):
    """Exchange a tensor across ``group``: split along ``scatter_dim``, concatenate along ``gather_dim``.

    Every rank sends one split to every other rank and receives one chunk from each,
    so e.g. a (bs, seq/N, heads, dim) shard becomes (bs, seq, heads/N, dim).
    """
    world_size = dist.get_world_size(group)
    shards = torch.tensor_split(local_input, world_size, scatter_dim)
    # all_to_all needs contiguous send buffers; recv buffers mirror the first shard's
    # shape (assumes scatter_dim is evenly divisible — TODO confirm for ragged sizes).
    send_chunks = [shard.contiguous() for shard in shards]
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in range(world_size)]
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    return torch.cat(recv_chunks, dim=gather_dim).contiguous()
class SeqAllToAll4D(torch.autograd.Function):
    """Differentiable all-to-all over a 4-D activation tensor.

    Forward scatters ``local_input`` along ``scatter_dim`` and gathers along
    ``gather_dim``; backward applies the inverse all-to-all (dims swapped) to the
    incoming gradient so gradients land back on the originating ranks.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
    ) -> Tensor:
        """Run the collective; stash group and dims on ``ctx`` for the backward pass."""
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]:
        """Return grads positionally matching forward's inputs; only ``local_input`` gets one."""
        return (
            None,
            all_to_all_tensor(grad_output[0], ctx.gather_dim, ctx.scatter_dim, ctx.group),
            None,
            None,
        )

View File

@@ -0,0 +1,199 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from functools import partial
import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformers
from ....accelerator.interface import Dim, DistributedInterface
from ....utils import logging
from ....utils.plugin import BasePlugin
from ....utils.types import ModelOutput
from .ulysses import (
UlyssesAttention,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_rank,
get_ulysses_sequence_parallel_world_size,
set_ulysses_sequence_parallel_group,
)
logger = logging.get_logger(__name__)
class SequenceParallelModelPlugin(BasePlugin):
    """Plugin hook that patches a model for sequence (context) parallelism.

    Instantiated with a mode name (e.g. ``"ulysses"``); calling it dispatches to the
    implementation registered under that name via ``BasePlugin`` (see
    ``apply_sequence_parallel`` registered below).
    """

    def __call__(self, model, model_args):
        return super().__call__(model, model_args)
class SequenceParallelLossPlugin(BasePlugin):
    """Plugin hook dispatching to a registered sequence-parallel loss implementation.

    See ``sequence_parallel_loss`` registered under ``"sequence_parallel_loss"`` below.
    """

    def __call__(self, model, inputs, *args, **kwargs):
        return super().__call__(model, inputs, *args, **kwargs)
def new_flash_attn_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    sequence_parallel_size=1,
    dropout=0,
    deterministic=False,
    is_causal=True,
    group=None,
    mode="ulysses",
    attn_fn=None,
    target_dtype=None,
    **kwargs,
):
    """Drop-in replacement for HF's ``_flash_attention_forward`` routing attention through Ulysses.

    ``attn_fn`` is the original flash-attention entry point, invoked by
    ``UlyssesAttention`` after the all-to-all head/sequence exchange.
    """
    if mode != "ulysses":
        raise NotImplementedError("Other sequence parallel modes are to be implemented.")
    dist_attn = UlyssesAttention(sequence_process_group=group, attn_fn=attn_fn)
    # Each rank holds seq_len / cp_size tokens, so the global query length is
    # the local length times the context-parallel degree.
    full_query_length = query_states.shape[1] * sequence_parallel_size
    return dist_attn(
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length=full_query_length,
        deterministic=deterministic,
        dropout_p=dropout,
        causal=is_causal,
        position_ids=kwargs.get("position_ids"),
        target_dtype=target_dtype,
    )
@SequenceParallelModelPlugin("ulysses").register()
def apply_sequence_parallel(model, model_args):
    """Patch transformers' ``_flash_attention_forward`` with the Ulysses sequence-parallel version.

    Registers the CP process group, validates that the head counts are compatible with
    the context-parallel degree, then rebinds ``_flash_attention_forward`` in every
    already-imported transformers module so the replacement takes effect regardless of
    how the symbol was re-exported.

    Args:
        model: HF model whose attention implementation is being patched.
        model_args: dist config dict; reads ``cp_size`` (default 1).
    """
    cp_size = model_args.get("cp_size", 1)
    set_ulysses_sequence_parallel_group(DistributedInterface().get_group(Dim.CP))
    try:
        num_attention_heads = model.config.num_attention_heads
        # BUG FIX: previously read num_attention_heads twice, so the KV-head
        # divisibility check below never saw the real GQA head count.
        num_key_value_heads = model.config.num_key_value_heads
    except AttributeError:
        # Multimodal models nest the language model's config under text_config.
        num_attention_heads = model.config.text_config.num_attention_heads
        num_key_value_heads = model.config.text_config.num_key_value_heads
    assert num_attention_heads % cp_size == 0, "num_attention_heads must be divisible by cp_size"
    assert num_key_value_heads % cp_size == 0 or cp_size % num_key_value_heads == 0, (
        "num_key_value_heads and cp_size must divide one another"
    )
    origin_attn = transformers.modeling_flash_attention_utils._flash_attention_forward
    new_flash_attention_forward = partial(
        new_flash_attn_forward,
        group=get_ulysses_sequence_parallel_group(),
        mode="ulysses",
        attn_fn=origin_attn,
        sequence_parallel_size=cp_size,
    )
    # Rebind the function in every loaded transformers module that imported it by name.
    for module_name, candidate in list(sys.modules.items()):
        try:
            if (
                getattr(candidate, "__file__", None) is not None
                and "transformers" in candidate.__file__
                and getattr(candidate._flash_attention_forward, "__name__", "") == "_flash_attention_forward"
            ):
                candidate._flash_attention_forward = new_flash_attention_forward
                logger.info_rank0(
                    f"Replaced _flash_attention_forward in module {module_name} with new_flash_attn_forward for sequence parallel."
                )
        except (AttributeError, TypeError):
            continue
def padding_and_split_data(data, device_mesh=None):
    """Pad each sequence tensor in ``data`` to a shared CP-divisible length and keep this rank's chunk.

    All CP ranks first agree on the maximum sequence length (via all_gather of the
    local lengths), pad up to the next multiple of ``cp_size``, then each rank keeps
    its own 1/cp_size slice along the last dimension. ``labels`` pad with -100
    (cross-entropy ignore_index) and ``loss_weights`` with 0.0 so padding never
    contributes to the loss.

    Args:
        data: dict of model inputs; only ``torch.Tensor`` values with ndim > 1 are touched.
        device_mesh: device mesh exposing a ``"cp"`` sub-mesh, or None.

    Returns:
        The same dict, mutated in place. Returned unchanged when ``device_mesh`` is None
        (previously the cp_* locals were unbound in that case).
    """
    if device_mesh is None:
        return data
    cp_size = device_mesh["cp"].size()
    cp_rank = device_mesh["cp"].get_local_rank()
    cp_group = device_mesh["cp"].get_group()
    for key, value in data.items():
        if not (isinstance(value, torch.Tensor) and value.ndim > 1):
            continue
        # Gather every rank's local sequence length so all ranks pad identically.
        local_len = torch.tensor(value.shape[-1], device=value.device, dtype=torch.int64)
        all_lens = [torch.empty_like(local_len) for _ in range(cp_size)]
        dist.all_gather(all_lens, local_len, group=cp_group)
        max_len = int(max(t.item() for t in all_lens))
        # Round the padded length up to a multiple of cp_size so chunks are equal.
        pad_size = max_len - value.shape[-1] + (cp_size - max_len % cp_size) % cp_size
        if key == "labels":
            pad_value = -100  # ignore_index for F.cross_entropy
        elif key == "loss_weights":
            pad_value = 0.0  # padded tokens carry no loss weight
        else:
            pad_value = 0
        padded = F.pad(value, (0, pad_size), value=pad_value)
        data[key] = torch.chunk(padded, chunks=cp_size, dim=-1)[cp_rank].contiguous()
    return data
@SequenceParallelLossPlugin("sequence_parallel_loss").register()
def sequence_parallel_loss(model, model_inputs):
    """Compute the weighted SFT loss when the sequence is sharded across CP ranks.

    Each rank runs the model on its local sequence shard; labels, loss weights and
    per-token log-probs are then all-gathered over the CP group so every rank
    computes the identical weighted-mean loss over the full sequence.
    """
    device_mesh = DistributedInterface().get_device_mesh(Dim.CP)
    # NOTE(review): uses the global rank as the device index — assumes one accelerator
    # per process with device ordinal == global rank; confirm for multi-node runs.
    model_inputs = {
        k: v.to(dist.get_rank(), non_blocking=True) for k, v in model_inputs.items() if isinstance(v, torch.Tensor)
    }
    # Pad to a CP-divisible length and keep only this rank's sequence chunk.
    model_inputs = padding_and_split_data(model_inputs, device_mesh)
    batch_size, _ = model_inputs["labels"].shape
    outputs: ModelOutput = model(**model_inputs)
    logits = outputs.logits.float()  # fp32 for a numerically stable cross entropy
    labels = model_inputs["labels"]
    cp_group = get_ulysses_sequence_parallel_group()
    cp_world_size = get_ulysses_sequence_parallel_world_size(cp_group)
    cp_rank = get_ulysses_sequence_parallel_rank(cp_group)
    # use all_gather to collect labels from all sequence parallel processes
    global_labels = [torch.empty_like(labels) for _ in range(cp_world_size)]
    dist.all_gather(global_labels, labels, group=cp_group)
    labels = torch.cat(global_labels, dim=1).contiguous()
    # Next-token shift on the FULL sequence, pad the tail with ignore_index, then
    # re-chunk so the label slice lines up with this rank's local logits.
    shift_labels = labels[..., 1:].view(-1).contiguous()
    shift_labels = F.pad(shift_labels, (0, 1), value=-100)
    shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=-1)[cp_rank].contiguous()
    # use all_gather to collect loss_weights from all sequence parallel processes
    loss_weights = model_inputs["loss_weights"]
    global_loss_weights = [torch.empty_like(loss_weights) for _ in range(cp_world_size)]
    dist.all_gather(global_loss_weights, loss_weights, group=cp_group)
    shift_loss_weights = torch.cat(global_loss_weights, dim=1).contiguous()
    shift_loss_weights = shift_loss_weights[..., 1:].contiguous()
    shift_logits = logits.view(shift_labels.size(0), -1).contiguous()
    # use all_gather to collect log_probs from all sequence parallel processes;
    # dist.nn.all_gather is differentiable, so gradients flow back to local logits.
    log_probs = -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
    global_log_probs = dist.nn.all_gather(log_probs, group=cp_group)
    global_log_probs = torch.cat(global_log_probs, dim=1).contiguous()
    log_probs = global_log_probs[..., :-1].contiguous()
    # Weighted mean over valid tokens; epsilon guards an all-padding batch.
    loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
    return loss

View File

@@ -0,0 +1,163 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from .seq_comm import SeqAllToAll4D
# Process group used for Ulysses sequence parallelism; configured once at startup.
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None


def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
    """Record ``group`` as the process group for Ulysses sequence parallelism."""
    global _ULYSSES_SEQUENCE_PARALLEL_GROUP
    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group


def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """Return the configured Ulysses process group, or ``None`` if unset."""
    return _ULYSSES_SEQUENCE_PARALLEL_GROUP


def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
    """Return the world size of ``group`` (falling back to the global group); 1 when no group exists."""
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return 1
    return dist.get_world_size(group)


def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
    """Return this process's rank in ``group`` (falling back to the global group); 0 when no group exists."""
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    if not group:
        return 0
    return dist.get_rank(group)
class UlyssesAttention(torch.nn.Module):
    """Ulysses-style distributed attention.

    Wraps a local flash-attention callable: head-parallel / sequence-parallel layouts
    are exchanged with all-to-all before and after the local attention call.

    Arguments:
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm (head dim on the way in)
        gather_idx (int): gather_idx for all2all comm (sequence dim on the way in)
        attn_fn (callable): the original (local) flash-attention forward to invoke
    """

    def __init__(
        self,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        attn_fn: Optional[callable] = None,
    ) -> None:
        super().__init__()
        self.spg = sequence_process_group  # CP process group for the all-to-alls
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.attn_fn = attn_fn

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        dropout_p=0.0,
        softmax_scale=None,
        position_ids: Optional[torch.Tensor] = None,
        causal=True,
        deterministic=False,
        target_dtype=None,
        *args: Any,
    ) -> Tensor:
        """Forward.

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            attention_mask (Tensor): attention mask for the layer
            query_length (int): the length of the FULL (un-sharded) query sequence
            dropout_p (float, optional): dropout probability. Defaults to 0.0.
            softmax_scale (float, optional): scale factor for softmax. Defaults to None,
            position_ids (torch.Tensor, optional): position ids for the attention. Defaults to None.
            causal (bool, optional): whether to apply causal mask. Defaults to True.
            deterministic (bool, optional): whether to apply dropout in deterministic way. Defaults to False.
            target_dtype (torch.dtype, optional): target dtype for attention output. Defaults to None.
            args: other args

        Returns:
            * output (Tensor): context output
        """
        # TODO Merge three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
        # in shape : e.g., [s/p:h:]
        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
        # scatter 2, gather 1
        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** -0.5
        # The local attn_fn sees the FULL sequence, so the mask must cover it too:
        # synthesize an all-ones mask when none is given, otherwise all-gather each
        # rank's local mask and concatenate along the sequence dimension.
        if attention_mask is None:
            if position_ids is not None:
                attention_mask = torch.ones_like(position_ids).to(torch.int64)
            else:
                attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
        else:
            attention_mask = attention_mask.to(torch.int64)
            global_attention_mask = [
                torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
            ]
            dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
            attention_mask = torch.cat(global_attention_mask, dim=1)
        context_layer = self.attn_fn(
            q,
            k,
            v,
            attention_mask,
            query_length=query_length,
            is_causal=causal,
            dropout=dropout_p,
            position_ids=position_ids,
            softmax_scale=softmax_scale,
            deterministic=deterministic,
            target_dtype=target_dtype,
        )
        # Some attention implementations return (output, aux...); keep only the output.
        if isinstance(context_layer, tuple):
            context_layer = context_layer[0]
        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        # out e.g., [s/p::h]
        return output

View File

@@ -150,9 +150,6 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
@PeftPlugin("lora").register() @PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel: def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
if model.device.type == "meta":
raise ValueError("Currently lora stage does not support loading model by meta.")
adapter_name_or_path = config.get("adapter_name_or_path") adapter_name_or_path = config.get("adapter_name_or_path")
if adapter_name_or_path: if adapter_name_or_path:

View File

@@ -17,6 +17,7 @@ import gc
import os import os
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from peft.tuners.lora import LoraLayer from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
@@ -84,9 +85,6 @@ class FSDP2Engine:
) )
if self.device_mesh is not None: if self.device_mesh is not None:
try:
self.fsdp_mesh = self.device_mesh["dp"]
except Exception:
self.fsdp_mesh = self.device_mesh self.fsdp_mesh = self.device_mesh
logger.info(f"Using Device Mesh: {self.fsdp_mesh}") logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
@@ -244,23 +242,57 @@ class FSDP2Engine:
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers") logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
def shard_model(self, model: HFModel) -> HFModel: def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta": init_mode = getattr(model, "_init_mode", "init_on_default")
if init_mode == "init_on_rank0":
if getattr(model.config, "tie_word_embeddings", False):
model.tie_weights()
if self.rank == 0:
logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.")
full_sd = {k: v.clone() for k, v in model.state_dict().items()}
else:
full_sd = {}
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
model = self.prepare_model(model)
device = get_current_accelerator()
model.to_empty(device=device)
# Scatter params from Rank 0 into all DTensor shards
# Broadcast the full state dict from the global rank-0 process to all ranks in this group.
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
set_model_state_dict(model, full_sd, options=options)
# Broadcast and restore non-persistent buffers
buffers_to_sync = [saved_buffers]
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
if self.rank == 0:
logger.info("init_on_rank0 sync complete.")
elif init_mode == "init_on_meta":
non_persistent_buffers = self._save_non_persistent_buffers(model) non_persistent_buffers = self._save_non_persistent_buffers(model)
if getattr(model.config, "tie_word_embeddings", None): if getattr(model.config, "tie_word_embeddings", False):
model.tie_weights() model.tie_weights()
model = self.prepare_model(model) model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
# fix tied broken for no-fsdp-wrap case # fix tied broken for no-fsdp-wrap case
if getattr(model.config, "tie_word_embeddings", None): if getattr(model.config, "tie_word_embeddings", False):
model.tie_weights() model.tie_weights()
self._restore_non_persistent_buffers(model, non_persistent_buffers) self._restore_non_persistent_buffers(model, non_persistent_buffers)
else: else:
model = self.prepare_model(model) model = self.prepare_model(model)
return model return model
def _load_from_dcp(self, model: HFModel, dcp_path: str): def _load_from_dcp(self, model: HFModel, dcp_path: str):

View File

@@ -0,0 +1,62 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.plugins.model_plugins.parallelization.sequence_parallel import (
SequenceParallelModelPlugin,
sequence_parallel_loss,
)
from llamafactory.v1.utils.env import find_available_port
from llamafactory.v1.utils.pytest import dist_env
def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int):
    """Worker entry point (one per spawned process): run a tiny CP forward pass and check the loss exists."""
    with dist_env(local_rank, world_size, master_port):
        model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
        # Initialize distributed interface with config
        dist_config = {"cp_mode": "ulysses", "cp_size": cp_size, "dp_size": dp_size}
        DistributedInterface(dist_config)
        # Now create model engine
        model_engine = ModelEngine(model_args=model_args)
        # Apply sequence parallel plugin
        SequenceParallelModelPlugin(dist_config.get("cp_mode", "ulysses"))(model_engine.model, dist_config)
        # Five-token toy batch; every rank feeds the same inputs, padding_and_split_data
        # inside the loss shards the sequence across the CP group.
        model_inputs = {
            "input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
            "labels": torch.tensor([[1, 2, 3, 4, 5]]),
            "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
            "position_ids": torch.tensor([[1, 2, 3, 4, 5]]),
            "loss_weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]),
        }
        loss = sequence_parallel_loss(model_engine.model, model_inputs)
        # Smoke check only: exact loss values are not stable across devices/dtypes.
        assert loss is not None
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
@pytest.mark.parametrize("cp_size, dp_size", [(2, 1)])
def test_sequence_parallel_loss(cp_size, dp_size):
    """Spawn cp_size * dp_size processes and run the sequence-parallel loss smoke test."""
    master_port = find_available_port()
    world_size = cp_size * dp_size
    # mp.spawn prepends the process rank as the worker's first argument (local_rank).
    mp.spawn(_test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size), nprocs=world_size)