[v1] support deepspeed (#10181)

This commit is contained in:
浮梦
2026-02-12 17:24:30 +08:00
committed by GitHub
parent 675ce8cc7f
commit 5c52afa30d
5 changed files with 257 additions and 47 deletions

View File

@@ -76,19 +76,28 @@ class BaseTrainer:
if self.args.enable_activation_checkpointing:
self.model.gradient_checkpointing_enable({"use_reentrant": False})
if self.args.dist_config is not None:
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
else:
shard_need_optimizer = False
self._accelerate_engine = None
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
if shard_need_optimizer:
if dist_name == "deepspeed":
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
self._deepspeed_engine = DistributedPlugin("deepspeed")(
self.model,
self.args.dist_config,
num_micro_batch=self.train_batch_generator.num_micro_batch,
micro_batch_size=self.args.micro_batch_size,
)
self._init_optimizer()
self._shard_model()
self._init_lr_scheduler()
self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
self.model, self.optimizer, self.lr_scheduler
)
else:
# fsdp2 / DDP / no dist
self._shard_model()
self._init_optimizer()
self._init_lr_scheduler()
self._init_lr_scheduler()
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
@@ -171,25 +180,35 @@ class BaseTrainer:
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
for micro_batch in micro_batches:
num_micro = len(micro_batches)
for i, micro_batch in enumerate(micro_batches):
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# fsdp uses mean reduction so we need to scale the loss by dp_size
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
loss.backward()
if self._deepspeed_engine is not None:
# deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
self._deepspeed_engine.backward(loss)
else:
loss.backward()
step_loss += loss.item()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
if self._deepspeed_engine is not None:
# deepspeed: engine.step() already ran inside backward at the sync boundary
grad_norm = self._deepspeed_engine.get_grad_norm()
else:
self.optimizer.step()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
else:
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
@@ -203,17 +222,14 @@ class BaseTrainer:
def save_model(self) -> None:
"""Save the model."""
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
state_dict = None
if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
state_dict = get_model_state_dict(self.model, options=options)
if DistributedInterface().get_rank() != 0:
return
model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
self.renderer.processor.save_pretrained(self.args.output_dir)
logger.info_rank0(f"Model saved to {self.args.output_dir}")
DistributedPlugin(self.args.dist_config.name).save_model(
self.model, self.args.output_dir, self.renderer.processor
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")