[v1] fix padding free with sp (#10513)

This commit is contained in:
jiaqiw09
2026-05-26 23:49:21 +08:00
committed by GitHub
parent 8e68764b65
commit 01398eb18d
6 changed files with 80 additions and 21 deletions

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -20,7 +20,7 @@ train_dataset: data/v1_sft_demo.yaml
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: fa2
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10

View File

@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
model_args, data_args, training_args, sample_args = parsed_args
# Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
for arg in parsed_args:
seed = getattr(arg, "seed", None)
if seed is not None:
set_seed(seed)
break
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
return tuple(parsed_args)
return model_args, data_args, training_args, sample_args
if __name__ == "__main__":

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
full_determinism: bool = field(
default=False,
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
)
resume_from_checkpoint: str | None = field(
default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},

View File

@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# in shape : e.g., [s/p:h:]
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
if attention_mask is None:
if position_ids is not None:
attention_mask = torch.ones_like(position_ids).to(torch.int64)
else:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
if position_ids is not None:
global_position_ids = [
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_position_ids, position_ids, group=self.spg)
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
attention_mask = None
else:
attention_mask = attention_mask.to(torch.int64)
if attention_mask is None:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
else:
attention_mask = attention_mask.to(torch.int64)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
context_layer = self.attn_fn(
q,

View File

@@ -13,22 +13,46 @@
# limitations under the License.
import random
import numpy as np
import torch
from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed
from ..accelerator.helper import is_torch_npu_available
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None:
def enable_full_determinism(seed: int) -> None:
"""Enable full deterministic mode for reproducible distributed training."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
if is_torch_npu_available():
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
def set_seed(seed: int, full_determinism: bool = False) -> None:
"""Set seed for reproducibility.
Args:
seed: Random seed.
full_determinism: Whether to enable full deterministic mode.
"""
hf_set_seed(seed)
if full_determinism:
enable_full_determinism(seed)
else:
hf_set_seed(seed)
def is_tokenizer(processor: Processor) -> bool: