[v1] add sft (#9752)

This commit is contained in:
Yaowei Zheng
2026-01-12 03:15:01 +08:00
committed by GitHub
parent 4d3621e3d3
commit 958b9c3468
29 changed files with 439 additions and 305 deletions

View File

@@ -16,6 +16,7 @@
import torch
from transformers import PreTrainedTokenizer
from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor
@@ -73,3 +74,20 @@ def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchIn
padded_samples.append(padded_sample)
return padded_samples
def compute_valid_tokens(batches: list[BatchInput]) -> int:
    """Compute valid tokens in batches.

    A token is counted as valid when its label differs from ``IGNORE_INDEX``.
    Batches that have no ``"labels"`` entry are skipped.

    Args:
        batches: Batches.

    Returns:
        Number of valid tokens (0 when no batch carries labels).
    """
    device = DistributedInterface().current_device
    # Keep every per-batch count as a tensor on ``device`` so the total can be
    # read back with a single ``.item()`` call instead of one call per batch.
    counts = [
        (batch["labels"].to(device, non_blocking=True) != IGNORE_INDEX).sum()
        for batch in batches
        if "labels" in batch
    ]
    if not counts:
        # Mirrors ``sum()`` over an empty iterable.
        return 0

    return int(torch.stack(counts).sum().item())