mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-08 12:46:06 +08:00
[v1] add sft (#9752)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
@@ -73,3 +74,20 @@ def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchIn
|
||||
padded_samples.append(padded_sample)
|
||||
|
||||
return padded_samples
|
||||
|
||||
|
||||
def compute_valid_tokens(batches: list[BatchInput]) -> int:
    """Count the label entries that are not masked out across batches.

    Args:
        batches: Batches to inspect. Entries without a "labels" key are skipped.

    Returns:
        Total number of label entries that differ from IGNORE_INDEX.
    """
    target_device = DistributedInterface().current_device
    total = 0
    for batch in batches:
        if "labels" not in batch:
            continue

        labels = batch["labels"].to(target_device, non_blocking=True)
        valid_mask = labels != IGNORE_INDEX
        # .item() converts the per-batch count into a plain Python number.
        total += valid_mask.sum().item()

    return total
|
||||
|
||||
Reference in New Issue
Block a user