[v1] support ulysses cp for fsdp2 (#10262)

This commit is contained in:
sunyi0505
2026-03-27 16:22:48 +08:00
committed by GitHub
parent df2e6edb7e
commit b5afabe3d2
8 changed files with 552 additions and 7 deletions

View File

@@ -71,6 +71,7 @@ class BaseTrainer:
# cached variables
self.device = DistributedInterface().current_device
self.dp_size = DistributedInterface().get_world_size(Dim.DP)
self.cp_size = DistributedInterface().get_world_size(Dim.CP)
self.model_input_names = self.renderer.processor.model_input_names
self._create_batch_generator()
@@ -114,6 +115,21 @@ class BaseTrainer:
# Callbacks: TrainerState tracks progress across the full run.
self.state = TrainerState(num_training_steps=self.num_training_steps)
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported yet because of its different attention implementation; support will be added in the future.
if model.config.model_type == "qwen3_5":
raise RuntimeError(
"Sequence parallel is not supported for qwen3.5 model due to its different attention implementation, which will be supported in the future."
)
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2":
logger.warning_rank0(
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
)
model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
@@ -172,7 +188,7 @@ class BaseTrainer:
"""
batch_size, _ = batch["labels"].shape
model_inputs = {
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)
}
labels = batch["labels"].to(self.device, non_blocking=True)
outputs: ModelOutput = model(**model_inputs)
@@ -206,7 +222,14 @@ class BaseTrainer:
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
num_micro = len(micro_batches)
for i, micro_batch in enumerate(micro_batches):
loss = self.compute_loss(micro_batch)
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
from ..plugins.model_plugins.parallelization.sequence_parallel import (
SequenceParallelLossPlugin,
)
loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch)
else:
loss = self.compute_loss(micro_batch)
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
# fsdp uses mean reduction so we need to scale the loss by dp_size
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
@@ -223,7 +246,24 @@ class BaseTrainer:
# deepspeed: engine.step() already ran inside backward at the sync boundary
grad_norm = self._deepspeed_engine.get_grad_norm()
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
parameters = self.model.parameters()
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
parameters = list(parameters)
grads = [p.grad for p in parameters if p.grad is not None]
grad_norm = _get_total_norm(grads)
grad_norm = grad_norm.to(self.device)
_clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
grad_norm = grad_norm.full_tensor().item()
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.args.max_grad_norm
).item()
# isfinite(): argument 'input' (position 1) must be Tensor, not float
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]

View File

@@ -146,6 +146,8 @@ class Renderer:
for sample in samples:
if "messages" in sample:
model_input = self.render_messages(sample["messages"], sample.get("tools"))
if "position_ids" not in model_input:
model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))