diff --git a/examples/v1/train_batching_strategy/train_full_fsdp2_dynamic_padding_free.yaml b/examples/v1/train_batching_strategy/train_full_fsdp2_dynamic_padding_free.yaml new file mode 100644 index 000000000..aa7ab54b0 --- /dev/null +++ b/examples/v1/train_batching_strategy/train_full_fsdp2_dynamic_padding_free.yaml @@ -0,0 +1,30 @@ +model: Qwen/Qwen3-0.6B +model_class: llm + +template: qwen3_nothink + +kernel_config: + name: auto + include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null + +quant_config: null + +dist_config: + name: fsdp2 + dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: outputs/test_fsdp2 +micro_batch_size: 4 +batching_strategy: dynamic_padding_free +flash_attn: flash_attention2 +cutoff_len: 2048 +learning_rate: 1.0e-4 +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/examples/v1/train_batching_strategy/train_full_fsdp2_padding_free.yaml b/examples/v1/train_batching_strategy/train_full_fsdp2_padding_free.yaml index b841cca80..2f96a065e 100644 --- a/examples/v1/train_batching_strategy/train_full_fsdp2_padding_free.yaml +++ b/examples/v1/train_batching_strategy/train_full_fsdp2_padding_free.yaml @@ -20,7 +20,7 @@ train_dataset: data/v1_sft_demo.yaml output_dir: outputs/test_fsdp2 micro_batch_size: 4 batching_strategy: padding_free -flash_attn: fa2 +flash_attn: flash_attention2 cutoff_len: 2048 learning_rate: 1.0e-4 max_steps: 10 diff --git a/src/llamafactory/v1/config/arg_parser.py b/src/llamafactory/v1/config/arg_parser.py index 0a0caddd2..9aa160644 100644 --- a/src/llamafactory/v1/config/arg_parser.py +++ b/src/llamafactory/v1/config/arg_parser.py @@ -57,15 +57,12 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments, print(f"Got unknown args, potentially deprecated arguments: {unknown_args}") raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}") + model_args, data_args, training_args, sample_args = parsed_args # Seed as early as possible after argument parsing so all downstream # components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state. - for arg in parsed_args: - seed = getattr(arg, "seed", None) - if seed is not None: - set_seed(seed) - break + set_seed(training_args.seed, full_determinism=training_args.full_determinism) - return tuple(parsed_args) + return model_args, data_args, training_args, sample_args if __name__ == "__main__": diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 0b5fc1ff2..200938b9c 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -85,6 +85,10 @@ class TrainingArguments: default=42, metadata={"help": "Random seed that will be set at the beginning of training."}, ) + full_determinism: bool = field( + default=False, + metadata={"help": "Enable full deterministic mode for reproducible distributed training."}, + ) resume_from_checkpoint: str | None = field( default=None, metadata={"help": "Path to a checkpoint directory to resume training from, or 'auto' to find the latest."}, diff --git a/src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py b/src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py index 85febc773..4bf568e30 100644 --- a/src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py +++ b/src/llamafactory/v1/plugins/model_plugins/parallelization/ulysses.py @@ -114,7 +114,6 @@ class UlyssesAttention(torch.nn.Module): # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! # in shape : e.g., [s/p:h:] # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size) - # scatter 2, gather 1 q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx) k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx) @@ -123,19 +122,24 @@ class UlyssesAttention(torch.nn.Module): if softmax_scale is None: softmax_scale = q.shape[-1] ** -0.5 - if attention_mask is None: - if position_ids is not None: - attention_mask = torch.ones_like(position_ids).to(torch.int64) - else: - attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device) + if position_ids is not None: + global_position_ids = [ + torch.empty_like(position_ids) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg)) + ] + dist.all_gather(global_position_ids, position_ids, group=self.spg) + position_ids = torch.cat(global_position_ids, dim=-1).contiguous() + attention_mask = None else: - attention_mask = attention_mask.to(torch.int64) + if attention_mask is None: + attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device) + else: + attention_mask = attention_mask.to(torch.int64) - global_attention_mask = [ - torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg)) - ] - dist.all_gather(global_attention_mask, attention_mask, group=self.spg) - attention_mask = torch.cat(global_attention_mask, dim=1) + global_attention_mask = [ + torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg)) + ] + dist.all_gather(global_attention_mask, attention_mask, group=self.spg) + attention_mask = torch.cat(global_attention_mask, dim=1) context_layer = self.attn_fn( q, diff --git a/src/llamafactory/v1/utils/helper.py b/src/llamafactory/v1/utils/helper.py index dd453093c..fb4c71b4a 100644 --- a/src/llamafactory/v1/utils/helper.py +++ b/src/llamafactory/v1/utils/helper.py @@ -13,22 +13,46 @@ # limitations under the License. +import random + +import numpy as np import torch from transformers import PreTrainedTokenizer from transformers import set_seed as hf_set_seed +from ..accelerator.helper import is_torch_npu_available from ..accelerator.interface import DistributedInterface from .constants import IGNORE_INDEX from .types import BatchInput, ModelInput, Processor, Tensor -def set_seed(seed: int) -> None: +def enable_full_determinism(seed: int) -> None: + """Enable full deterministic mode for reproducible distributed training.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.use_deterministic_algorithms(True, warn_only=True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + if is_torch_npu_available(): + torch.npu.manual_seed(seed) + torch.npu.manual_seed_all(seed) + + +def set_seed(seed: int, full_determinism: bool = False) -> None: """Set seed for reproducibility. Args: seed: Random seed. + full_determinism: Whether to enable full deterministic mode. """ - hf_set_seed(seed) + if full_determinism: + enable_full_determinism(seed) + else: + hf_set_seed(seed) def is_tokenizer(processor: Processor) -> bool: