diff --git a/examples/v1/train_full/train_full_deepspeed.yaml b/examples/v1/train_full/train_full_deepspeed.yaml
new file mode 100644
index 000000000..29d9353cd
--- /dev/null
+++ b/examples/v1/train_full/train_full_deepspeed.yaml
@@ -0,0 +1,25 @@
+model: Qwen/Qwen3-0.6B
+
+model_class: llm
+
+template: qwen3_nothink
+
+kernel_config:
+  name: auto
+  include_kernels: auto
+
+dist_config:
+  name: deepspeed
+  config_file: examples/deepspeed/ds_z3_config.json
+
+### data
+train_dataset: data/v1_sft_demo.yaml
+
+### training
+output_dir: outputs/Qwen3-0.6B-deepspeed
+micro_batch_size: 1
+cutoff_len: 2048
+learning_rate: 1.0e-4
+bf16: true
+max_steps: 10
+
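For reference, the `config_file` this YAML points at is an HF-style ZeRO-3 JSON. A minimal sketch of the kind of content `examples/deepspeed/ds_z3_config.json` is expected to hold (the repo's actual file may differ); the `"auto"` values are the ones the new `DeepSpeedEngine` resolves at prepare time, and `DeepSpeedPlugin` accepts either a dict like this or a path to the JSON file:

```python
# Hypothetical sketch of a ZeRO-3 config; not the repo's actual file.
from accelerate.utils import DeepSpeedPlugin

ds_z3_config = {
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",  # resolved from micro_batch_size (see deepspeed.py below)
}
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_z3_config)
```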
diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py
index 76c3911e1..69660361e 100644
--- a/src/llamafactory/v1/core/base_trainer.py
+++ b/src/llamafactory/v1/core/base_trainer.py
@@ -76,19 +76,28 @@ class BaseTrainer:
         if self.args.enable_activation_checkpointing:
             self.model.gradient_checkpointing_enable({"use_reentrant": False})
 
-        if self.args.dist_config is not None:
-            shard_need_optimizer = self.args.dist_config.name == "deepspeed"
-        else:
-            shard_need_optimizer = False
+        self._deepspeed_engine = None
+        dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
 
-        if shard_need_optimizer:
+        if dist_name == "deepspeed":
+            from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
+
+            self._deepspeed_engine = DistributedPlugin("deepspeed")(
+                self.model,
+                self.args.dist_config,
+                num_micro_batch=self.train_batch_generator.num_micro_batch,
+                micro_batch_size=self.args.micro_batch_size,
+            )
             self._init_optimizer()
-            self._shard_model()
+            self._init_lr_scheduler()
+            self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
+                self.model, self.optimizer, self.lr_scheduler
+            )
         else:
+            # fsdp2 / DDP / no dist
             self._shard_model()
            self._init_optimizer()
-
-        self._init_lr_scheduler()
+            self._init_lr_scheduler()
 
     def _create_batch_generator(self) -> None:
         self.train_batch_generator = BatchGenerator(
@@ -171,25 +180,35 @@ class BaseTrainer:
             step_loss = 0
             step_valid_tokens = compute_valid_tokens(micro_batches)
             step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
-            for micro_batch in micro_batches:
+            num_micro = len(micro_batches)
+            for i, micro_batch in enumerate(micro_batches):
                 loss = self.compute_loss(micro_batch)
                 mini_step_valid_tokens = compute_valid_tokens([micro_batch])
                 # fsdp uses mean reduction so we need to scale the loss by dp_size
                 loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
-                loss.backward()
+                if self._deepspeed_engine is not None:
+                    # deepspeed: set sync_gradients so engine.step() only fires on the last micro-batch
+                    self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
+                    self._deepspeed_engine.backward(loss)
+                else:
+                    loss.backward()
                 step_loss += loss.item()
 
-            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
-
-            # isfinite(): argument 'input' (position 1) must be Tensor, not float
-            if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
-                logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
+            if self._deepspeed_engine is not None:
+                # deepspeed: engine.step() already ran inside backward at the sync boundary
+                grad_norm = self._deepspeed_engine.get_grad_norm()
             else:
-                self.optimizer.step()
+                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
 
-            self.lr_scheduler.step()
-            self.optimizer.zero_grad()
+                # isfinite(): argument 'input' (position 1) must be Tensor, not float
+                if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
+                    logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
+                else:
+                    self.optimizer.step()
+
+                self.lr_scheduler.step()
+                self.optimizer.zero_grad()
 
             step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
             DistributedInterface().sync()
@@ -203,17 +222,14 @@ class BaseTrainer:
 
     def save_model(self) -> None:
         """Save the model."""
-        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
-        state_dict = None
-        if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
-            from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
+        if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
+            from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
 
-            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
-            state_dict = get_model_state_dict(self.model, options=options)
-
-        if DistributedInterface().get_rank() != 0:
-            return
-
-        model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
-        self.renderer.processor.save_pretrained(self.args.output_dir)
-        logger.info_rank0(f"Model saved to {self.args.output_dir}")
+            DistributedPlugin(self.args.dist_config.name).save_model(
+                self.model, self.args.output_dir, self.renderer.processor
+            )
+        else:
+            model_to_save = self.model.module if hasattr(self.model, "module") else self.model
+            model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
+            self.renderer.processor.save_pretrained(self.args.output_dir)
+            logger.info_rank0(f"Model saved to {self.args.output_dir}")
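The accumulation logic above hinges on accelerate only calling `engine.step()` at the sync boundary. A minimal standalone sketch of that pattern, with `accelerator`, `compute_loss`, and `micro_batches` as hypothetical stand-ins for the trainer's real objects, assuming a DeepSpeed-prepared `Accelerator`:

```python
# Sketch of the micro-batch loop above. With DeepSpeed, accelerator.backward()
# delegates to accelerate's DeepSpeedEngineWrapper, which calls engine.step()
# only when sync_gradients is True, i.e. on the last micro-batch.
for i, micro_batch in enumerate(micro_batches):
    loss = compute_loss(micro_batch)
    accelerator.sync_gradients = i == len(micro_batches) - 1  # mark the boundary
    accelerator.backward(loss)  # engine.backward(); engine.step() at the boundary
```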
+""" + +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from accelerate.utils import DeepSpeedPlugin + +from ....utils.logging import get_logger +from ....utils.types import HFModel, Processor + + +logger = get_logger(__name__) + + +class DeepSpeedEngine: + """DeepSpeed integration using accelerate's built-in capabilities. + + This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach + with accelerate's Accelerator + DeepSpeedPlugin, which handles: + - Config syncing (auto values, batch size, lr, etc.) + - deepspeed.initialize() call + - Optimizer / LR scheduler wrapping + - Backward + gradient accumulation boundary + - ZeRO-3 parameter gathering for saving + """ + + def __init__(self, dist_config: dict[str, Any], num_micro_batch: int = 1, micro_batch_size: int = 1): + config_file = dist_config.get("config_file") + if not config_file: + raise ValueError("DeepSpeed config_file is required in dist_config") + + ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file) + + self.accelerator = Accelerator( + deepspeed_plugin=ds_plugin, + gradient_accumulation_steps=num_micro_batch, + ) + + # Resolve "auto" for train_micro_batch_size_per_gpu so that + # accelerate.prepare() does not require a DataLoader to infer it. + ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config + if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"): + ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size + + logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}") + + def shard_model(self, model: HFModel) -> "DeepSpeedEngine": + """No-op shard — actual model wrapping happens in prepare(). + + Returns self so the caller gets the engine instance via the hub interface. + """ + return self + + def prepare( + self, + model: HFModel, + optimizer: torch.optim.Optimizer, + lr_scheduler: Optional[Any] = None, + ) -> tuple[HFModel, torch.optim.Optimizer, Any]: + """Prepare model, optimizer, and lr_scheduler using accelerate. + + Internally calls deepspeed.initialize() and wraps the returned objects. + """ + if lr_scheduler is not None: + model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler) + else: + model, optimizer = self.accelerator.prepare(model, optimizer) + + model._accelerator = self.accelerator # type: ignore[assignment] + + logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate") + return model, optimizer, lr_scheduler + + def backward(self, loss: torch.Tensor) -> None: + """Backward pass using accelerate. + + Delegates to DeepSpeedEngineWrapper.backward() which respects + sync_gradients to control gradient accumulation boundaries. + When sync_gradients=True: engine.backward(loss) + engine.step() + When sync_gradients=False: engine.backward(loss) only + """ + self.accelerator.backward(loss) + + def get_grad_norm(self) -> float: + """Get the global gradient norm from the DeepSpeed engine.""" + engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None) + if engine_wrapper is not None: + return engine_wrapper.engine.get_global_grad_norm() or 0.0 + return 0.0 + + +def save_model(model: HFModel, output_dir: str, processor: Processor) -> None: + """Save model using accelerate's built-in ZeRO-aware utilities. + + Expects model._accelerator to be set during prepare(). + Handles ZeRO-3 parameter gathering automatically via + accelerator.get_state_dict(). 
+ """ + accelerator: Accelerator = model._accelerator # type: ignore[union-attr] + + unwrapped_model = accelerator.unwrap_model(model) + state_dict = accelerator.get_state_dict(model) + + if accelerator.is_main_process: + unwrapped_model.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB") + processor.save_pretrained(output_dir, max_shard_size="4GB") + + accelerator.wait_for_everyone() + logger.info_rank0(f"Model saved to {output_dir}") diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py index e3c3020aa..dbe2626bf 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -17,24 +17,24 @@ import os import torch import torch.nn as nn +from peft.tuners.lora import LoraLayer from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict from torch.distributed.fsdp import ( CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard, ) -from transformers import PreTrainedModel -from peft.tuners.lora import LoraLayer from ....accelerator.helper import get_current_accelerator from ....accelerator.interface import DistributedInterface from ....utils.logging import get_logger +from ....utils.types import HFModel, Processor logger = get_logger(__name__) -def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None: +def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None: no_split_modules = getattr(model, "_no_split_modules", None) if no_split_modules: if isinstance(no_split_modules, (list, tuple)): @@ -50,6 +50,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None: return None +def save_model(model: HFModel, output_dir: str, processor: Processor) -> None: + if DistributedInterface().get_rank() == 0: + logger.info("Gathering state dict for saving...") + + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + state_dict = get_model_state_dict(model, options=options) + + if DistributedInterface().get_rank() == 0: + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB") + processor.save_pretrained(output_dir, max_shard_size="4GB") + logger.info(f"Model saved to {output_dir}") + + class FSDP2Engine: def __init__(self, dist_config: dict): self.dist_interface = DistributedInterface() @@ -94,12 +108,11 @@ class FSDP2Engine: reduce_dtype=reduce_dtype, cast_forward_inputs=True, ) - def is_lora_module_wrap(self, model) -> bool: return any(isinstance(module, LoraLayer) for module in model.modules()) - def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel: + def prepare_model(self, model: HFModel) -> HFModel: if self.fsdp_mesh is None: logger.warning("No FSDP Mesh available, skipping FSDP wrapping.") return model @@ -115,11 +128,10 @@ class FSDP2Engine: else: logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}") transformer_layer_cls_to_wrap = {layer_cls} - + if self.is_lora_module_wrap(model): lora_modules = [] for module in model.modules(): - if len(list(module.children())) != 0: continue if any(param.requires_grad for param in module.parameters(recurse=False)): @@ -134,7 +146,7 @@ class FSDP2Engine: offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None, ) - logger.info(f"Applying FSDP wrap for LoRA layer separately.") + 
logger.info("Applying FSDP wrap for LoRA layer separately.") for name, module in model.named_modules(): should_wrap = False @@ -179,8 +191,9 @@ class FSDP2Engine: ) return model + @torch.no_grad() - def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None): + def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None): if self.rank == 0: logger.info("Materializing sharded model params...") @@ -200,7 +213,7 @@ class FSDP2Engine: return model - def shard_model(self, model: PreTrainedModel) -> PreTrainedModel: + def shard_model(self, model: HFModel) -> HFModel: if model.device.type == "meta": model = self.prepare_model(model) model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) @@ -208,7 +221,7 @@ class FSDP2Engine: model = self.prepare_model(model) return model - def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str): + def _load_from_dcp(self, model: HFModel, dcp_path: str): import torch.distributed.checkpoint as dcp try: @@ -227,7 +240,7 @@ class FSDP2Engine: logger.error(f"Failed to load from DCP: {e}") raise e - def _load_weights_from_hf_checkpoint(self, model, hf_model_path): + def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str): import glob import json diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py index 096cae14e..24b7052c2 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + from ....config.arg_utils import PluginConfig from ....utils.plugin import BasePlugin -from ....utils.types import HFModel + + +if TYPE_CHECKING: + from ....utils.types import HFModel, Processor class DistributedPlugin(BasePlugin): @@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin): @DistributedPlugin("fsdp2").register() -def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel: +def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: from .fsdp2 import FSDP2Engine return FSDP2Engine(dist_config).shard_model(model) +@DistributedPlugin("fsdp2").register("save_model") +def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None: + from .fsdp2 import save_model + + return save_model(model, output_dir, processor) + + @DistributedPlugin("deepspeed").register() -def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel: - return model +def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: + from .deepspeed import DeepSpeedEngine + + return DeepSpeedEngine( + dist_config, + num_micro_batch=kwargs.get("num_micro_batch"), + micro_batch_size=kwargs.get("micro_batch_size"), + ).shard_model(model) + + +@DistributedPlugin("deepspeed").register("save_model") +def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None: + from .deepspeed import save_model + + return save_model(model, output_dir, processor)