[v1] support deepspeed (#10181)

This commit is contained in:
浮梦
2026-02-12 17:24:30 +08:00
committed by GitHub
parent 675ce8cc7f
commit 5c52afa30d
5 changed files with 257 additions and 47 deletions

View File

@@ -0,0 +1,25 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto
dist_config:
name: deepspeed
config_file: examples/deepspeed/ds_z3_config.json
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/Qwen3-0.6B-deepspeed
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10

View File

@@ -76,19 +76,28 @@ class BaseTrainer:
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
if self.args.dist_config is not None:
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
else:
shard_need_optimizer = False
self._accelerate_engine = None
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
if shard_need_optimizer:
if dist_name == "deepspeed":
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self._deepspeed_engine = DistributedPlugin("deepspeed")(
self.model,
self.args.dist_config,
num_micro_batch=self.train_batch_generator.num_micro_batch,
micro_batch_size=self.args.micro_batch_size,
)
self._init_optimizer()
self._shard_model()
self._init_lr_scheduler()
self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
self.model, self.optimizer, self.lr_scheduler
)
else:
# fsdp2 / DDP / no dist
self._shard_model()
self._init_optimizer()
self._init_lr_scheduler()
self._init_lr_scheduler()
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
@@ -171,25 +180,35 @@ class BaseTrainer:
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
for micro_batch in micro_batches:
num_micro = len(micro_batches)
for i, micro_batch in enumerate(micro_batches):
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# fsdp uses mean reduction so we need to scale the loss by dp_size
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
loss.backward()
if self._deepspeed_engine is not None:
# deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
self._deepspeed_engine.backward(loss)
else:
loss.backward()
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
if self._deepspeed_engine is not None:
# deepspeed: engine.step() already ran inside backward at the sync boundary
grad_norm = self._deepspeed_engine.get_grad_norm()
else:
self.optimizer.step()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
@@ -203,17 +222,14 @@ class BaseTrainer:
def save_model(self) -> None:
"""Save the model."""
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
state_dict = None
if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(self.model, options=options)
if DistributedInterface().get_rank() != 0:
return
model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")
DistributedPlugin(self.args.dist_config.name).save_model(
self.model, self.args.output_dir, self.renderer.processor
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -0,0 +1,129 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepSpeed integration via accelerate's built-in capabilities.
Instead of manually calling deepspeed.initialize() and syncing config,
this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
initialization, backward, gradient accumulation, and model saving.
"""
from typing import Any, Optional
import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
class DeepSpeedEngine:
"""DeepSpeed integration using accelerate's built-in capabilities.
This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach
with accelerate's Accelerator + DeepSpeedPlugin, which handles:
- Config syncing (auto values, batch size, lr, etc.)
- deepspeed.initialize() call
- Optimizer / LR scheduler wrapping
- Backward + gradient accumulation boundary
- ZeRO-3 parameter gathering for saving
"""
def __init__(self, dist_config: dict[str, Any], num_micro_batch: int = 1, micro_batch_size: int = 1):
config_file = dist_config.get("config_file")
if not config_file:
raise ValueError("DeepSpeed config_file is required in dist_config")
ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
self.accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
gradient_accumulation_steps=num_micro_batch,
)
# Resolve "auto" for train_micro_batch_size_per_gpu so that
# accelerate.prepare() does not require a DataLoader to infer it.
ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size
logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}")
def shard_model(self, model: HFModel) -> "DeepSpeedEngine":
"""No-op shard — actual model wrapping happens in prepare().
Returns self so the caller gets the engine instance via the hub interface.
"""
return self
def prepare(
self,
model: HFModel,
optimizer: torch.optim.Optimizer,
lr_scheduler: Optional[Any] = None,
) -> tuple[HFModel, torch.optim.Optimizer, Any]:
"""Prepare model, optimizer, and lr_scheduler using accelerate.
Internally calls deepspeed.initialize() and wraps the returned objects.
"""
if lr_scheduler is not None:
model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler)
else:
model, optimizer = self.accelerator.prepare(model, optimizer)
model._accelerator = self.accelerator # type: ignore[assignment]
logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate")
return model, optimizer, lr_scheduler
def backward(self, loss: torch.Tensor) -> None:
"""Backward pass using accelerate.
Delegates to DeepSpeedEngineWrapper.backward() which respects
sync_gradients to control gradient accumulation boundaries.
When sync_gradients=True: engine.backward(loss) + engine.step()
When sync_gradients=False: engine.backward(loss) only
"""
self.accelerator.backward(loss)
def get_grad_norm(self) -> float:
"""Get the global gradient norm from the DeepSpeed engine."""
engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None)
if engine_wrapper is not None:
return engine_wrapper.engine.get_global_grad_norm() or 0.0
return 0.0
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
"""Save model using accelerate's built-in ZeRO-aware utilities.
Expects model._accelerator to be set during prepare().
Handles ZeRO-3 parameter gathering automatically via
accelerator.get_state_dict().
"""
accelerator: Accelerator = model._accelerator # type: ignore[union-attr]
unwrapped_model = accelerator.unwrap_model(model)
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
unwrapped_model.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
processor.save_pretrained(output_dir, max_shard_size="4GB")
accelerator.wait_for_everyone()
logger.info_rank0(f"Model saved to {output_dir}")

View File

@@ -17,24 +17,24 @@ import os
import torch
import torch.nn as nn
from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
fully_shard,
)
from transformers import PreTrainedModel
from peft.tuners.lora import LoraLayer
from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
if isinstance(no_split_modules, (list, tuple)):
@@ -50,6 +50,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
return None
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
if DistributedInterface().get_rank() == 0:
logger.info("Gathering state dict for saving...")
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(model, options=options)
if DistributedInterface().get_rank() == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
processor.save_pretrained(output_dir, max_shard_size="4GB")
logger.info(f"Model saved to {output_dir}")
class FSDP2Engine:
def __init__(self, dist_config: dict):
self.dist_interface = DistributedInterface()
@@ -94,12 +108,11 @@ class FSDP2Engine:
reduce_dtype=reduce_dtype,
cast_forward_inputs=True,
)
def is_lora_module_wrap(self, model) -> bool:
return any(isinstance(module, LoraLayer) for module in model.modules())
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
def prepare_model(self, model: HFModel) -> HFModel:
if self.fsdp_mesh is None:
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
return model
@@ -115,11 +128,10 @@ class FSDP2Engine:
else:
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
transformer_layer_cls_to_wrap = {layer_cls}
if self.is_lora_module_wrap(model):
lora_modules = []
for module in model.modules():
if len(list(module.children())) != 0:
continue
if any(param.requires_grad for param in module.parameters(recurse=False)):
@@ -134,7 +146,7 @@ class FSDP2Engine:
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
)
logger.info(f"Applying FSDP wrap for LoRA layer separately.")
logger.info("Applying FSDP wrap for LoRA layer separately.")
for name, module in model.named_modules():
should_wrap = False
@@ -179,8 +191,9 @@ class FSDP2Engine:
)
return model
@torch.no_grad()
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
if self.rank == 0:
logger.info("Materializing sharded model params...")
@@ -200,7 +213,7 @@ class FSDP2Engine:
return model
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta":
model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
@@ -208,7 +221,7 @@ class FSDP2Engine:
model = self.prepare_model(model)
return model
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
def _load_from_dcp(self, model: HFModel, dcp_path: str):
import torch.distributed.checkpoint as dcp
try:
@@ -227,7 +240,7 @@ class FSDP2Engine:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob
import json

View File

@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
if TYPE_CHECKING:
from ....utils.types import HFModel, Processor
class DistributedPlugin(BasePlugin):
@@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin):
@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .fsdp2 import FSDP2Engine
return FSDP2Engine(dist_config).shard_model(model)
@DistributedPlugin("fsdp2").register("save_model")
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
from .fsdp2 import save_model
return save_model(model, output_dir, processor)
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
return model
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(
dist_config,
num_micro_batch=kwargs.get("num_micro_batch"),
micro_batch_size=kwargs.get("micro_batch_size"),
).shard_model(model)
@DistributedPlugin("deepspeed").register("save_model")
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
from .deepspeed import save_model
return save_model(model, output_dir, processor)