diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index bac346890..97289c94a 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -269,26 +269,13 @@ class BaseTrainer: # deepspeed: engine.step() already ran inside backward at the sync boundary grad_norm = self._deepspeed_engine.get_grad_norm() else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item() + if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1: - from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm + grad_norm = grad_norm**2 + grad_norm = DistributedInterface().all_reduce(grad_norm, op=ReduceOp.SUM, dim=Dim.CP) + grad_norm = grad_norm**0.5 - parameters = self.model.parameters() - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - else: - parameters = list(parameters) - grads = [p.grad for p in parameters if p.grad is not None] - grad_norm = _get_total_norm(grads) - grad_norm = grad_norm.to(self.device) - _clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm) - if isinstance(grad_norm, torch.distributed._tensor.DTensor): - grad_norm = grad_norm.full_tensor().item() - else: - grad_norm = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args.max_grad_norm - ).item() - - # isfinite(): argument 'input' (position 1) must be Tensor, not float if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType] logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}") else: diff --git a/src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py b/src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py index 35d7b0323..8d5073bb4 100644 --- a/src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py +++ b/src/llamafactory/v1/plugins/model_plugins/parallelization/sequence_parallel.py @@ -175,9 +175,9 @@ def sequence_parallel_loss(model, model_inputs): global_labels = [torch.empty_like(labels) for _ in range(cp_world_size)] dist.all_gather(global_labels, labels, group=cp_group) labels = torch.cat(global_labels, dim=1).contiguous() - shift_labels = labels[..., 1:].view(-1).contiguous() + shift_labels = labels[..., 1:].contiguous() shift_labels = F.pad(shift_labels, (0, 1), value=-100) - shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=-1)[cp_rank].contiguous() + shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=1)[cp_rank].contiguous() # use all_gather to collect loss_weights from all sequence parallel processes loss_weights = model_inputs["loss_weights"] @@ -186,7 +186,8 @@ def sequence_parallel_loss(model, model_inputs): shift_loss_weights = torch.cat(global_loss_weights, dim=1).contiguous() shift_loss_weights = shift_loss_weights[..., 1:].contiguous() - shift_logits = logits.view(shift_labels.size(0), -1).contiguous() + shift_logits = logits.view(-1, logits.size(-1)).contiguous() + shift_labels = shift_labels.view(-1).contiguous() # use all_gather to collect log_probs from all sequence parallel processes log_probs = -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1) diff --git a/tests_v1/plugins/model_plugins/test_ulysses_cp.py b/tests_v1/plugins/model_plugins/test_ulysses_cp.py index 8de66c027..746c2a0c2 100644 --- a/tests_v1/plugins/model_plugins/test_ulysses_cp.py +++ b/tests_v1/plugins/model_plugins/test_ulysses_cp.py @@ -27,7 +27,9 @@ from llamafactory.v1.utils.env import find_available_port from llamafactory.v1.utils.pytest import dist_env -def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int): +def _test_sequence_parallel_loss( + local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int, batch_size: int +): with dist_env(local_rank, world_size, master_port): model_args = ModelArguments(model="llamafactory/tiny-random-qwen3") @@ -41,12 +43,13 @@ def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: # Apply sequence parallel plugin SequenceParallelModelPlugin(dist_config.get("cp_mode", "ulysses"))(model_engine.model, dist_config) + input_ids = torch.arange(1, batch_size * 5 + 1, dtype=torch.long).view(batch_size, 5) model_inputs = { - "input_ids": torch.tensor([[1, 2, 3, 4, 5]]), - "labels": torch.tensor([[1, 2, 3, 4, 5]]), - "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), - "position_ids": torch.tensor([[1, 2, 3, 4, 5]]), - "loss_weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]), + "input_ids": input_ids, + "labels": input_ids.clone(), + "attention_mask": torch.ones_like(input_ids), + "position_ids": torch.arange(1, 6, dtype=torch.long).repeat(batch_size, 1), + "loss_weights": torch.ones(batch_size, 5), } loss = sequence_parallel_loss(model_engine.model, model_inputs) @@ -55,8 +58,10 @@ def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: @pytest.mark.runs_on(["cuda", "npu"]) @pytest.mark.require_distributed(2) -@pytest.mark.parametrize("cp_size, dp_size", [(2, 1)]) -def test_sequence_parallel_loss(cp_size, dp_size): +@pytest.mark.parametrize(("cp_size", "dp_size", "batch_size"), [(2, 1, 1), (2, 1, 2)]) +def test_sequence_parallel_loss(cp_size, dp_size, batch_size): master_port = find_available_port() world_size = cp_size * dp_size - mp.spawn(_test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size), nprocs=world_size) + mp.spawn( + _test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size, batch_size), nprocs=world_size + )