mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] fix padding free with sp (#10513)
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
quant_config: null
|
||||
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 4
|
||||
batching_strategy: dynamic_padding_free
|
||||
flash_attn: flash_attention2
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -20,7 +20,7 @@ train_dataset: data/v1_sft_demo.yaml
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 4
|
||||
batching_strategy: padding_free
|
||||
flash_attn: fa2
|
||||
flash_attn: flash_attention2
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
max_steps: 10
|
||||
|
||||
@@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
model_args, data_args, training_args, sample_args = parsed_args
|
||||
# Seed as early as possible after argument parsing so all downstream
|
||||
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||
for arg in parsed_args:
|
||||
seed = getattr(arg, "seed", None)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
break
|
||||
set_seed(training_args.seed, full_determinism=training_args.full_determinism)
|
||||
|
||||
return tuple(parsed_args)
|
||||
return model_args, data_args, training_args, sample_args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -85,6 +85,10 @@ class TrainingArguments:
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
full_determinism: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable full deterministic mode for reproducible distributed training."},
|
||||
)
|
||||
resume_from_checkpoint: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."},
|
||||
|
||||
@@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module):
|
||||
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
|
||||
# in shape : e.g., [s/p:h:]
|
||||
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
|
||||
|
||||
# scatter 2, gather 1
|
||||
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
|
||||
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
|
||||
@@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** -0.5
|
||||
|
||||
if attention_mask is None:
|
||||
if position_ids is not None:
|
||||
attention_mask = torch.ones_like(position_ids).to(torch.int64)
|
||||
else:
|
||||
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
|
||||
if position_ids is not None:
|
||||
global_position_ids = [
|
||||
torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_position_ids, position_ids, group=self.spg)
|
||||
position_ids = torch.cat(global_position_ids, dim=-1).contiguous()
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = attention_mask.to(torch.int64)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(torch.int64)
|
||||
|
||||
global_attention_mask = [
|
||||
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
|
||||
attention_mask = torch.cat(global_attention_mask, dim=1)
|
||||
global_attention_mask = [
|
||||
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
|
||||
]
|
||||
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
|
||||
attention_mask = torch.cat(global_attention_mask, dim=1)
|
||||
|
||||
context_layer = self.attn_fn(
|
||||
q,
|
||||
|
||||
@@ -13,22 +13,46 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import set_seed as hf_set_seed
|
||||
|
||||
from ..accelerator.helper import is_torch_npu_available
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
def enable_full_determinism(seed: int) -> None:
    """Seed every random number generator and force deterministic kernels.

    Seeds Python's ``random``, NumPy, and torch (CPU, the current CUDA device,
    and all CUDA devices), requests deterministic torch algorithms, and turns
    cuDNN off. When an NPU is available, its generators are seeded as well.

    Args:
        seed: Value applied to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # warn_only=True: ops without a deterministic implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False

    if not is_torch_npu_available():
        return

    torch.npu.manual_seed(seed)
    torch.npu.manual_seed_all(seed)
|
||||
|
||||
|
||||
def set_seed(seed: int, full_determinism: bool = False) -> None:
    """Set seed for reproducibility.

    Always seeds through ``hf_set_seed``; when ``full_determinism`` is true,
    ``enable_full_determinism`` is applied afterwards with the same seed.

    Args:
        seed: Random seed.
        full_determinism: Whether to enable full deterministic mode.
    """
    # Seeding is unconditional, so a separate non-deterministic branch that
    # seeds again with the same value would be redundant.
    hf_set_seed(seed)
    if full_determinism:
        enable_full_determinism(seed)
|
||||
|
||||
|
||||
def is_tokenizer(processor: Processor) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user