mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-03 18:25:59 +08:00
[algo] add ASFT (#10174)
This commit is contained in:
@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import Trainer
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
@@ -681,6 +682,88 @@ def _dft_cross_entropy(
|
||||
return loss
|
||||
|
||||
|
||||
def asft_loss_func(
|
||||
outputs,
|
||||
labels: torch.Tensor,
|
||||
ref_logits: torch.Tensor,
|
||||
asft_alpha: float = 0.1,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
|
||||
logits = logits.float()
|
||||
|
||||
# shift for causal LM
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_ref_logits = ref_logits[..., :-1, :].contiguous()
|
||||
|
||||
vocab_size = shift_logits.size(-1)
|
||||
|
||||
# flatten
|
||||
shift_logits = shift_logits.view(-1, vocab_size)
|
||||
shift_ref_logits = shift_ref_logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||
|
||||
return _asft_cross_entropy(
|
||||
policy_logits=shift_logits,
|
||||
policy_labels=shift_labels,
|
||||
ref_logits=shift_ref_logits,
|
||||
asft_alpha=asft_alpha,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
|
||||
def _asft_cross_entropy(
|
||||
policy_logits: torch.Tensor,
|
||||
policy_labels: torch.Tensor,
|
||||
ref_logits: torch.Tensor,
|
||||
asft_alpha: float = 0.1,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
dft_loss = _dft_cross_entropy(
|
||||
policy_logits,
|
||||
policy_labels,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
kl_loss = _kl_divergence(
|
||||
policy_logits,
|
||||
ref_logits,
|
||||
policy_labels,
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
return dft_loss + asft_alpha * kl_loss
|
||||
|
||||
|
||||
def _kl_divergence(
|
||||
policy_logits: torch.Tensor,
|
||||
ref_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
# log p(y|x)
|
||||
log_p = F.log_softmax(policy_logits, dim=-1)
|
||||
|
||||
# q(y|x)
|
||||
q = F.softmax(ref_logits, dim=-1)
|
||||
|
||||
# token-wise KL
|
||||
kl = F.kl_div(
|
||||
log_p,
|
||||
q,
|
||||
reduction="none",
|
||||
).sum(dim=-1) # [N]
|
||||
|
||||
# mask padding tokens
|
||||
mask = (labels != ignore_index).float()
|
||||
|
||||
return (kl * mask).sum() / mask.sum()
|
||||
|
||||
|
||||
def eaft_loss_func(
|
||||
outputs: "torch.Tensor",
|
||||
labels: "torch.Tensor",
|
||||
|
||||
Reference in New Issue
Block a user