[v1] support ulysses cp for fsdp2 (#10262)
@@ -71,6 +71,7 @@ class BaseTrainer:
         # cached variables
         self.device = DistributedInterface().current_device
         self.dp_size = DistributedInterface().get_world_size(Dim.DP)
+        self.cp_size = DistributedInterface().get_world_size(Dim.CP)
         self.model_input_names = self.renderer.processor.model_input_names

         self._create_batch_generator()
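Note on the new cached value: cp_size sits next to dp_size because context parallelism adds a second dimension to the process layout, i.e. the world is arranged as dp_size x cp_size ranks. A minimal sketch of that arrangement with plain torch.distributed (the project's own DistributedInterface/Dim wrapper hides this; build_mesh below is a hypothetical helper, assuming a torchrun launch):

    # Sketch only: a 2D (dp, cp) device mesh, analogous to what DistributedInterface exposes.
    from torch.distributed.device_mesh import init_device_mesh

    def build_mesh(dp_size: int, cp_size: int):
        # Assumes a distributed launch (e.g. torchrun) with world_size == dp_size * cp_size.
        mesh = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
        # mesh["dp"].size() / mesh["cp"].size() correspond to self.dp_size / self.cp_size.
        return mesh["dp"].size(), mesh["cp"].size()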
@@ -114,6 +115,21 @@ class BaseTrainer:
         # Callbacks: TrainerState tracks progress across the full run.
         self.state = TrainerState(num_training_steps=self.num_training_steps)

+        if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
+            # qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
+            if model.config.model_type == "qwen3_5":
+                raise RuntimeError(
+                    "Sequence parallel is not supported for qwen3.5 model due to its different attention implementation, which will be supported in the future."
+                )
+            from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
+
+            if model.config._attn_implementation != "flash_attention_2":
+                logger.warning_rank0(
+                    "Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
+                )
+                model.config._attn_implementation = "flash_attention_2"
+            SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
+
     def _create_batch_generator(self) -> None:
         self.train_batch_generator = BatchGenerator(
             dataset=self.train_dataset,
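For context on what SequenceParallelModelPlugin patches in: with Ulysses-style context parallelism each CP rank holds seq/cp_size tokens, and an all-to-all before attention trades the sequence shard for a head shard so that (flash) attention runs over the full sequence on heads/cp_size heads; a second all-to-all restores the original layout afterwards. A single-process, shape-only sketch of that exchange (not the plugin's actual code, which performs the exchange with collectives on the CP group):

    # Sketch: simulate the Ulysses all-to-all for all CP ranks in one process.
    import torch

    def ulysses_exchange(q: torch.Tensor, cp_size: int) -> list[torch.Tensor]:
        """q: [batch, seq, heads, head_dim] -> per-rank views [batch, seq, heads // cp_size, head_dim]."""
        batch, seq, heads, head_dim = q.shape
        seq_shards = q.chunk(cp_size, dim=1)  # what each rank holds before attention
        rank_views = []
        for rank in range(cp_size):
            # after the all-to-all: full sequence, but only this rank's slice of the heads
            head_slice = slice(rank * heads // cp_size, (rank + 1) * heads // cp_size)
            rank_views.append(torch.cat([s[:, :, head_slice] for s in seq_shards], dim=1))
        return rank_views

    views = ulysses_exchange(torch.randn(2, 16, 8, 4), cp_size=4)
    assert views[0].shape == (2, 16, 2, 4)  # [batch, full_seq, heads / cp, head_dim]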
@@ -172,7 +188,7 @@ class BaseTrainer:
         """
         batch_size, _ = batch["labels"].shape
         model_inputs = {
-            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
+            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)
         }
         labels = batch["labels"].to(self.device, non_blocking=True)
         outputs: ModelOutput = model(**model_inputs)
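The relaxed filter matters because, with CP enabled, the collated batch can carry tensor keys (e.g. position_ids injected by the renderer) that are not listed in model_input_names; filtering on isinstance(v, torch.Tensor) forwards them while still dropping non-tensor entries. A toy illustration with made-up batch contents:

    import torch

    batch = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
        "position_ids": torch.tensor([[1, 2, 3]]),  # added by the renderer for CP
        "labels": torch.tensor([[-100, 2, 3]]),
        "metadata": {"source": "demo"},             # non-tensor entries are skipped either way
    }
    model_input_names = ["input_ids", "attention_mask"]

    old = {k: v for k, v in batch.items() if k in model_input_names}
    new = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
    assert "position_ids" not in old and "position_ids" in new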
@@ -206,7 +222,14 @@ class BaseTrainer:
         step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
         num_micro = len(micro_batches)
         for i, micro_batch in enumerate(micro_batches):
-            loss = self.compute_loss(micro_batch)
+            if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
+                from ..plugins.model_plugins.parallelization.sequence_parallel import (
+                    SequenceParallelLossPlugin,
+                )
+
+                loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch)
+            else:
+                loss = self.compute_loss(micro_batch)
             mini_step_valid_tokens = compute_valid_tokens([micro_batch])
             # fsdp uses mean reduction so we need to scale the loss by dp_size
             loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
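On the scaling line: the per-micro-batch loss is a mean over that micro-batch's valid tokens, and FSDP averages gradients across the dp_size data-parallel ranks, so multiplying by mini_step_valid_tokens * dp_size / step_valid_tokens turns the aggregate into a global per-token mean. Because gradients are linear in the loss, the scalar arithmetic below (made-up numbers) carries over to the gradient scaling:

    # Numeric sketch of the token-weighted scaling, assuming mean-per-token micro-batch losses.
    dp_size = 2
    rank_losses = [[2.0, 4.0], [1.0, 3.0]]  # rank -> micro-batch mean losses
    rank_tokens = [[10, 30], [20, 40]]      # rank -> valid tokens per micro-batch

    step_valid_tokens = sum(sum(t) for t in rank_tokens)  # all-reduced across DP ranks

    # Each rank accumulates its scaled losses; FSDP then averages over dp_size,
    # so the dp_size factor cancels and the result is a global per-token mean.
    per_rank_sums = [
        sum(l * t * dp_size / step_valid_tokens for l, t in zip(losses, tokens))
        for losses, tokens in zip(rank_losses, rank_tokens)
    ]
    fsdp_mean = sum(per_rank_sums) / dp_size
    token_weighted_mean = sum(
        l * t for losses, tokens in zip(rank_losses, rank_tokens) for l, t in zip(losses, tokens)
    ) / step_valid_tokens
    assert abs(fsdp_mean - token_weighted_mean) < 1e-9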
@@ -223,7 +246,24 @@ class BaseTrainer:
             # deepspeed: engine.step() already ran inside backward at the sync boundary
             grad_norm = self._deepspeed_engine.get_grad_norm()
         else:
-            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
+            if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
+                from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
+
+                parameters = self.model.parameters()
+                if isinstance(parameters, torch.Tensor):
+                    parameters = [parameters]
+                else:
+                    parameters = list(parameters)
+                grads = [p.grad for p in parameters if p.grad is not None]
+                grad_norm = _get_total_norm(grads)
+                grad_norm = grad_norm.to(self.device)
+                _clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
+                if isinstance(grad_norm, torch.distributed._tensor.DTensor):
+                    grad_norm = grad_norm.full_tensor().item()
+            else:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), self.args.max_grad_norm
+                ).item()

         # isfinite(): argument 'input' (position 1) must be Tensor, not float
         if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
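_get_total_norm and _clip_grads_with_norm_ are private torch helpers, used here because under CP with FSDP2 the gradients can be DTensors, so the total norm itself comes back as a DTensor and needs .full_tensor() before .item(). In terms of public ops, the two-step clip amounts to the following sketch (plain tensors only, not the DTensor-aware implementation):

    import torch

    def clip_by_total_norm_(grads: list[torch.Tensor], max_norm: float) -> torch.Tensor:
        # total norm over all gradients, then a single scale factor applied in place
        total_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g) for g in grads])
        )
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            for g in grads:
                g.mul_(clip_coef)
        return total_norm

    grads = [torch.full((4,), 3.0), torch.full((4,), 4.0)]
    norm = clip_by_total_norm_(grads, max_norm=1.0)  # total norm = sqrt(4*9 + 4*16) = 10
    assert torch.isclose(norm, torch.tensor(10.0))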
@@ -146,6 +146,8 @@ class Renderer:
         for sample in samples:
             if "messages" in sample:
                 model_input = self.render_messages(sample["messages"], sample.get("tools"))
+                if "position_ids" not in model_input:
+                    model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
             elif "chosen_messages" in sample and "rejected_messages" in sample:
                 chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
                 rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
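The renderer-side change guarantees position_ids on every rendered sample so that, once the sequence is later sharded across CP ranks, each shard keeps its absolute positions instead of being re-numbered from zero. A toy illustration of the default (1-based, as in the diff) and of a shard keeping its positions:

    input_ids = [101, 7, 8, 9, 10, 11, 12, 102]
    position_ids = list(range(1, len(input_ids) + 1))  # [1, 2, ..., 8]

    cp_size = 2
    shard = 1  # the second CP rank
    chunk = len(input_ids) // cp_size
    shard_positions = position_ids[shard * chunk : (shard + 1) * chunk]
    assert shard_positions == [5, 6, 7, 8]  # absolute positions survive the split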