[v1] fix padding free with sp (#10513)

This commit is contained in:
jiaqiw09
2026-05-26 23:49:21 +08:00
committed by GitHub
parent 8e68764b65
commit 01398eb18d
6 changed files with 80 additions and 21 deletions

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choices: null, true, false, auto, or a comma-separated list of kernel IDs (e.g. kernel_id1,kernel_id2,kernel_id3); default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # e.g. /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: dynamic_padding_free
flash_attn: flash_attention2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -20,7 +20,7 @@ train_dataset: data/v1_sft_demo.yaml
output_dir: outputs/test_fsdp2 output_dir: outputs/test_fsdp2
micro_batch_size: 4 micro_batch_size: 4
batching_strategy: padding_free batching_strategy: padding_free
flash_attn: fa2 flash_attn: flash_attention2
cutoff_len: 2048 cutoff_len: 2048
learning_rate: 1.0e-4 learning_rate: 1.0e-4
max_steps: 10 max_steps: 10

View File

@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
model_args, data_args, training_args, sample_args = parsed_args
# Seed as early as possible after argument parsing so all downstream # Seed as early as possible after argument parsing so all downstream
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state. # components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
for arg in parsed_args: set_seed(training_args.seed, full_determinism=training_args.full_determinism)
seed = getattr(arg, "seed", None)
if seed is not None:
set_seed(seed)
break
return tuple(parsed_args) return model_args, data_args, training_args, sample_args
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
default=42, default=42,
metadata={"help": "Random seed that will be set at the beginning of training."}, metadata={"help": "Random seed that will be set at the beginning of training."},
) )
full_determinism: bool = field(
default=False,
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
)
resume_from_checkpoint: str | None = field( resume_from_checkpoint: str | None = field(
default=None, default=None,
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."}, metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},

View File

@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# in shape : e.g., [s/p:h:] # in shape : e.g., [s/p:h:]
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1 # scatter 2, gather 1
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx) q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx) k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
@@ -123,10 +122,15 @@ class UlyssesAttention(torch.nn.Module):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5 softmax_scale = q.shape[-1] ** -0.5
if attention_mask is None:
if position_ids is not None: if position_ids is not None:
attention_mask = torch.ones_like(position_ids).to(torch.int64) global_position_ids = [
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_position_ids, position_ids, group=self.spg)
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
attention_mask = None
else: else:
if attention_mask is None:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device) attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
else: else:
attention_mask = attention_mask.to(torch.int64) attention_mask = attention_mask.to(torch.int64)

View File

@@ -13,21 +13,45 @@
# limitations under the License. # limitations under the License.
import random
import numpy as np
import torch import torch
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed from transformers import set_seed as hf_set_seed
from ..accelerator.helper import is_torch_npu_available
from ..accelerator.interface import DistributedInterface from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor from .types import BatchInput, ModelInput, Processor, Tensor
def set_seed(seed: int) -> None: def enable_full_determinism(seed: int) -> None:
"""Enable full deterministic mode for reproducible distributed training."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
if is_torch_npu_available():
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
def set_seed(seed: int, full_determinism: bool = False) -> None:
"""Set seed for reproducibility. """Set seed for reproducibility.
Args: Args:
seed: Random seed. seed: Random seed.
full_determinism: Whether to enable full deterministic mode.
""" """
if full_determinism:
enable_full_determinism(seed)
else:
hf_set_seed(seed) hf_set_seed(seed)