mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-03 18:25:59 +08:00
[v1] add renderer ut (#9722)
This commit is contained in:
@@ -634,7 +634,9 @@ def get_batch_logps(
|
||||
return logps, valid_length
|
||||
|
||||
|
||||
def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
def dft_loss_func(
|
||||
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
|
||||
):
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
|
||||
|
||||
def _dft_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
@@ -679,7 +681,12 @@ def _dft_cross_entropy(
|
||||
return loss
|
||||
|
||||
|
||||
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
|
||||
def eaft_loss_func(
|
||||
outputs: "torch.Tensor",
|
||||
labels: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
) -> "torch.Tensor":
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
@@ -697,12 +704,12 @@ def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
|
||||
|
||||
|
||||
def _eaft_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
@@ -712,13 +719,13 @@ def _eaft_cross_entropy(
|
||||
|
||||
with torch.no_grad():
|
||||
source_detached = source[valid_mask].detach()
|
||||
|
||||
|
||||
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
|
||||
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
|
||||
log_probs_topk = topk_val - logsumexp_topk
|
||||
probs_topk = torch.exp(log_probs_topk)
|
||||
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
|
||||
|
||||
|
||||
entropy_term = entropy_approx / 3.0
|
||||
adaptive_weight = torch.pow(entropy_term, alpha)
|
||||
|
||||
@@ -731,6 +738,7 @@ def _eaft_cross_entropy(
|
||||
loss = total_loss / num_items_in_batch
|
||||
else:
|
||||
loss = weighted_losses.mean()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user