[v1] add renderer ut (#9722)

This commit is contained in:
Yaowei Zheng
2026-01-07 02:06:07 +08:00
committed by GitHub
parent ea0b4e2466
commit d22de0d4bf
13 changed files with 420 additions and 249 deletions

View File

@@ -634,7 +634,9 @@ def get_batch_logps(
return logps, valid_length
def dft_loss_func(outputs, labels, num_items_in_batch=None):
def dft_loss_func(
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
):
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
def _dft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
@@ -679,7 +681,12 @@ def _dft_cross_entropy(
return loss
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
) -> "torch.Tensor":
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
@@ -697,12 +704,12 @@ def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
def _eaft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
@@ -712,13 +719,13 @@ def _eaft_cross_entropy(
with torch.no_grad():
source_detached = source[valid_mask].detach()
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
log_probs_topk = topk_val - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
entropy_term = entropy_approx / 3.0
adaptive_weight = torch.pow(entropy_term, alpha)
@@ -731,6 +738,7 @@ def _eaft_cross_entropy(
loss = total_loss / num_items_in_batch
else:
loss = weighted_losses.mean()
return loss