[v1&WIP] dataloader init (#9645)

Kingsley
2025-12-23 16:29:47 +08:00
committed by GitHub
parent 7901b2f32e
commit 1c8a42d2f8
7 changed files with 981 additions and 18 deletions

View File

@@ -5,6 +5,7 @@ datasets>=2.16.0,<=4.0.0
accelerate>=1.3.0,<=1.11.0
peft>=0.14.0,<=0.17.1
trl>=0.8.6,<=0.9.6
torchdata
# gui
gradio>=4.38.0,<=5.45.0
matplotlib>=3.7.0
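
torchdata is pulled in for torchdata.stateful_dataloader.StatefulDataLoader, which the new DistributedDataloader below subclasses. A minimal, illustrative sketch of the checkpoint/resume contract this dependency provides (not part of the diff; toy dataset and numbers chosen for the example):

from torchdata.stateful_dataloader import StatefulDataLoader

dl = StatefulDataLoader(list(range(10)), batch_size=2)
it = iter(dl)
next(it)                               # consume one batch -> tensor([0, 1])
state = dl.state_dict()                # mid-epoch snapshot
resumed = StatefulDataLoader(list(range(10)), batch_size=2)
resumed.load_state_dict(state)         # a fresh loader picks up the snapshot
print(next(iter(resumed)))             # continues after the consumed batch -> tensor([2, 3])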

View File

@@ -0,0 +1,277 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import sys
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from typing import Any, Optional
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from ..utils.batching_queue import BaseBatchingQueue
from ..utils.logging import get_logger
from ..utils.types import Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
logger = get_logger(__name__)
# base dataloader
class DistributedDataloader(StatefulDataLoader):
"""Base Distributed DataLoader."""
dataset: "TorchDataset"
sampler: "StatefulDistributedSampler"
def set_epoch(self, epoch: int) -> None:
if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
self.sampler.set_epoch(epoch)
elif hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)
@dataclass
class BaseDataLoader:
"""Default DataLoader."""
processor: Processor
def __init__(self, dataset: TorchDataset) -> None:
self.dataset = dataset
# guidelines: fetch until a fixed batch size is reached.
# save state_dict for the buffer.
# resume from the saved state.
# 1. Init stateful dataloader (tokenize)
# 2. Add to buffer (2 * max seq len per device)
# 3. Yield batch indexes (micro batch * grad acc)
# a ) non pack + non dynamic
# b ) non pack + dynamic
# c ) pack + non dynamic
# d ) pack + dynamic
def init_dataloader(self) -> None:
### init dataloader
pass
def __iter__(self) -> Iterator:
pass
def __next__(self) -> Any:
pass
@dataclass
class DataLoader:
"""Default DataLoader."""
processor: "Processor"
dataloader: "DistributedDataloader"
batching_queue: "BaseBatchingQueue"
collate_fn: "DataCollator"
num_micro_batch: int = 1
length: int = 0
drop_last: bool = True
def __init__(
self,
dataloader: Any,
collate_fn: "DataCollator",
num_micro_batch: int = 1,
length: int = 0,
drop_last: bool = True,
batching_queue: Optional["BaseBatchingQueue"] = None,
) -> None:
self.batching_queue = batching_queue
self.num_micro_batch = num_micro_batch
self.step = 0
self._collate_fn = collate_fn
self._dataloader = dataloader
self._drop_last = drop_last
self._data_iter: Iterator
self._resume = False
self._batch_data_iter: Generator
if length > 0:
self._length = length
elif length == -1:
self._length = sys.maxsize
else:
self._length = len(self._dataloader)
def __len__(self):
return self._length
def __iter__(self) -> Iterator:
if not self._resume:
self.step = 0
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
self._resume = False
return self
def __next__(self):
return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here
def origin_batch_data_generator(self):
"""Standard pass-through generator if do not use batching queue."""
while True:
if self._length > 0 and self.step >= self._length:
return
try:
batch = []
data = next(self._data_iter)
# split data into micro batches
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
if self.step < self._length:
# Restart iterator to fill the requested length
self._data_iter = iter(self._dataloader)
try:
batch = []
data = next(self._data_iter)
for i in range(0, len(data), self.num_micro_batch):
micro_batch = data[i : i + self.num_micro_batch]
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
yield batch
self.step += 1
except StopIteration:
return
else:
return
except Exception as e:
logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
raise
def batch_data_generator(self):
if self.batching_queue is None:
yield from self.origin_batch_data_generator()
return
batch = []
while True:
if self._length and self.step >= self._length:
return
if self.batching_queue.is_full_filled():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
try:
processing_item = next(self._data_iter)
except Exception as e:
if isinstance(e, StopIteration):
if self.step < self._length:
# restart the iterator until the requested length is reached
self._data_iter = iter(self._dataloader)
processing_item = next(self._data_iter)
elif not self._drop_last and not self.batching_queue.empty():
while not self.batching_queue.empty():
micro_batch = self.batching_queue.get_micro_batch(self.step)
if self._collate_fn:
micro_batch = self._collate_fn(micro_batch)
batch.append(micro_batch)
if len(batch) == self.num_micro_batch:
yield batch
self.step += 1
batch = []
while len(batch) < self.num_micro_batch:
padding_batch = copy.deepcopy(micro_batch)
padding_batch["is_padded"] = True
batch.append(padding_batch)
yield batch
self.step += 1
return
else:
return
else:
logger.error(f"DataLoader iter data exception: {e}")
raise
# put processing_item to buffer
if isinstance(processing_item, dict):
processing_item = [processing_item]
for item in processing_item:
self.batching_queue.put_item(item)
def state_dict(self):
# save state
state = self.__dict__.copy()
# remove internal fields
for k in list(state.keys()):
if k.startswith("_"):
del state[k]
# save dataloader state
if hasattr(self._dataloader, "state_dict"):
state["dataloader_state"] = self._dataloader.state_dict()
elif hasattr(self._dataloader, "__getstate__"):
state["dataloader_state"] = self._dataloader.__getstate__()
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy and hasattr(batching_strategy, "state_dict"):
state["batching_strategy_state"] = batching_strategy.state_dict()
if "batching_strategy" in state:
del state["batching_strategy"]
return copy.deepcopy(state)
def load_state_dict(self, state: dict[str, Any]):
if state["num_micro_batch"] != self.num_micro_batch:
logger.warning(
f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
)
del state["num_micro_batch"]
self.__dict__.update(state)
self._resume = True
if hasattr(self._dataloader, "load_state_dict"):
self._dataloader.load_state_dict(state["dataloader_state"])
elif hasattr(self._dataloader, "__getstate__"):
self._dataloader.__setstate__(state["dataloader_state"])
if "batching_strategy_state" in state:
batching_strategy = getattr(self, "batching_strategy", None)
if batching_strategy:
batching_strategy.load_state_dict(state["batching_strategy_state"])
del state["batching_strategy_state"]
self._data_iter = iter(self._dataloader)
self._batch_data_iter = self.batch_data_generator()
def set_epoch(self, epoch: int) -> None:
if hasattr(self._dataloader, "set_epoch"):
self._dataloader.set_epoch(epoch)
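
A minimal usage sketch of the non-dynamic path implemented by origin_batch_data_generator (illustrative, not part of the commit; the import path follows the tests further below, and toy_collate stands in for the real DataCollator): the inner dataloader hands over a plain list of feature dicts, and DataLoader re-chunks it into num_micro_batch collated micro batches per step.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader as TorchDataLoader
from llamafactory.v1.core.data_loader import DataLoader

samples = [{"input_ids": torch.arange(n)} for n in (3, 5, 2, 4)]        # toy features
inner = TorchDataLoader(samples, batch_size=4, collate_fn=lambda x: x)  # keep the dicts un-collated

def toy_collate(features):                                              # stand-in for DataCollator
    return {"input_ids": pad_sequence([f["input_ids"] for f in features], batch_first=True)}

loader = DataLoader(dataloader=inner, collate_fn=toy_collate, num_micro_batch=2, batching_queue=None)
for batch in loader:               # one optimizer step worth of data
    for micro_batch in batch:      # len(batch) == num_micro_batch == 2
        print(micro_batch["input_ids"].shape)  # (2, 5) then (2, 4)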

View File

@@ -12,36 +12,108 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from ...utils.types import Processor, Tensor, TorchDataset
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate
from ....extras.constants import IGNORE_INDEX
from ...plugins.data_plugins.template import Template
from ...utils.types import Processor, Tensor
def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor": # FIXME move to utils
"""Convert sequence lengths to cumulative sequence lengths."""
return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
class DataCollator:
"""Default Data collator."""
def __init__(self, processor: Processor) -> None:
self.processor = processor
processor: "Processor" # processor name -> map to encode_messages function
def __post_init__(self):
# callback for text tokenizer
self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
for feature in features:
pass
batch = defaultdict(list)
# batching features
for feature in features:
for key in feature.keys():
batch[key].append(feature[key])
for key in batch.keys():
# process padding features
if key in ["input_ids", "attention_mask", "position_ids"]:
padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
elif key in ["labels"]:
batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
else:
batch[key] = default_collate(batch[key])
return batch
# sft: messages
# dpo: chosen_messages, rejected_messages
class DataLoader:
"""Default DataLoader."""
@dataclass
class DefaultCollator(DataCollator):
"""Example for now."""
def __init__(self, dataset: TorchDataset) -> None:
self.dataset = dataset
# 1. Init stateful dataloader (tokenize)
# 2. Add to buffer (2 * max seq len per device)
# 3. Yield batch indexes (micro batch * grad acc)
# a ) non pack + non dynamic
# b ) non pack + dynamic
# c ) pack + non dynamic
# d ) pack + dynamic
processor: "Processor" # processor name -> map to encode_messages function
template: "Template"
def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
features = []
# Check if data is already tokenized (contains input_ids)
if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
for feature in messages:
if not isinstance(feature, dict):
raise ValueError(f"Expected dict but got {type(feature)}")
tensor_feature = {
k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
for k, v in feature.items()
}
features.append(tensor_feature)
else:
# raw messages need to be encoded
for message in messages:
encoded_message = self.template.encode_messages(self.tokenizer, message)
encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
features.append(encoded_message)
return super().__call__(features)
@dataclass
class PairwiseCollator(DataCollator):
pass
@dataclass
class DataCollatorWithPacking(DefaultCollator):
"""Data collator with packing."""
processor: "Processor"
template: "Template"
def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
batch = {"cu_seqlens": len2culen(seqlens)}
for input_name in features[0].keys():
if input_name in ("input_ids", "attention_mask", "labels"):
batch[input_name] = torch.cat([feature[input_name] for feature in features])
else:
batch[input_name] = default_collate([feature[input_name] for feature in features])
return batch
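
A worked example of the packing path (illustrative, not part of the diff): len2culen turns per-sample lengths into flash-attention style cumulative offsets, and DataCollatorWithPacking concatenates sequences along the token dimension instead of padding them.

import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 2, 4])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)  # what len2culen computes
print(cu_seqlens)  # tensor([0, 3, 5, 9], dtype=torch.int32)
# the packed input_ids row has length 9; sample i occupies positions cu_seqlens[i]:cu_seqlens[i + 1]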

View File

@@ -14,6 +14,7 @@
from dataclasses import dataclass
from typing import Union
@dataclass
@@ -22,5 +23,112 @@ class Template:
assistant_template: str
system_template: str
def render_message(self, message: "dict[str, str]") -> str:
def render_message(self, message: dict[str, str]) -> str:
return self.user_template.format(**message)
@dataclass
class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
if isinstance(content_data, list):
parts = []
for item in content_data:
if item.get("type") == "text":
parts.append(item.get("text", item.get("value", "")))  # text items usually carry a "text" key (see the test block below)
elif item.get("type") == "image_url":
pass
return "\n".join(parts).strip()
return ""
def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))
if role == "assistant":
reasoning_content = message.get("reasoning_content", "")
if reasoning_content:
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
return self.message_template.format(role="assistant", content=reasoning_content + content)
else:
return self.message_template.format(role=role, content=content)
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> dict[str, list[int]]:
"""Encode a list of messages into model inputs (input_ids, attention_mask, labels, position_ids)."""
input_ids, attention_mask, labels = [], [], []
for message in messages:
content_str = self.render_message(message)
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
input_ids += content_ids
attention_mask += [1] * len(content_ids)
if "loss_weight" in message:  # message is a dict, so use a membership test rather than hasattr
loss_weight = message["loss_weight"]
else:
loss_weight = 1 if message["role"] == "assistant" else 0
if loss_weight == 1:
labels += content_ids
else:
labels += [-100] * len(content_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
model_inputs.update({"position_ids": list(range(len(input_ids)))})
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
return model_inputs
if __name__ == "__main__":
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
out = []
for m in messages:
role = m["role"]
content = template._extract_content(m.get("content", ""))
if role == "assistant":
reasoning = (m.get("reasoning_content") or "").strip()
if reasoning:
content = template.thinking_template.format(content=reasoning) + content
out.append({"role": role, "content": content})
return out
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-30B-A3B-Thinking-2507",
trust_remote_code=True,
)
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
},
{
"role": "assistant",
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
},
]
template = QwenTemplate()
rendered_custom = "".join([template.render_message(m) for m in test_messages])
qwen3_messages = to_qwen3_messages(template, test_messages)
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
print("==== custom ====")
print(rendered_custom)
print("==== hf ====")
print(rendered_hf)
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
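
Beyond the round-trip check against apply_chat_template above, the key property of encode_messages is that only assistant turns are supervised. A small illustrative check (not part of the commit; the tiny tokenizer name is borrowed from the tests further below):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
template = QwenTemplate()
enc = template.encode_messages(tok, [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
])
n_user = len(tok.encode(template.render_message({"role": "user", "content": "hi"}), add_special_tokens=False))
assert enc["labels"][:n_user] == [-100] * n_user            # user tokens are masked out of the loss
assert enc["labels"][n_user:] == enc["input_ids"][n_user:]  # assistant tokens are supervised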

View File

@@ -0,0 +1,220 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any
class DynamicBatchSizeBuffer:
"""A buffer to store samples for dynamic batch size."""
def __init__(self):
self._buffer: list[dict[str, Any]] = []
self._buffer_sample_lengths: list[int] = []
self._deleted_indices: set[int] = set()
self._current_index: int = 0
self._total_token_count: int = 0
def append(self, item: dict[str, Any]) -> None:
"""Append a sample to the buffer.
Args:
item: A sample to append to the buffer.
The sample should be a dict with the following keys:
- input_ids: torch.Tensor of shape (seq_len, )
- attention_mask: torch.Tensor of shape (seq_len, )
"""
self._buffer.append(item)
sample_length = int(item["attention_mask"].sum().item())
self._buffer_sample_lengths.append(sample_length)
self._total_token_count += sample_length
def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, Any]]:
"""Get samples from the buffer that fit within the token budget.
Args:
max_tokens_per_iteration: Maximum number of tokens to retrieve.
force: If True, the first available sample will be returned even
if it exceeds the token budget.
Returns:
A list of samples that fit within the token budget.
Raises:
AssertionError: If no samples are found (should not happen in normal operation).
"""
cum_seq_len = 0
samples = []
while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
if self._current_index in self._deleted_indices:
self._current_index += 1
continue
seq_len = self._buffer_sample_lengths[self._current_index]
remaining_tokens = max_tokens_per_iteration - cum_seq_len
# Check if we can add this sample
can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
if can_add:
cum_seq_len += seq_len
samples.append(self._buffer[self._current_index])
self._deleted_indices.add(self._current_index)
self._current_index += 1
assert len(samples) > 0, "No samples found in buffer"
return samples
def __len__(self) -> int:
"""Return the number of samples in the buffer."""
return len(self._buffer)
@property
def total_token_count(self) -> int:
"""Return the total number of tokens in the buffer."""
return self._total_token_count
def flush(self) -> None:
tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
self._total_token_count -= tokens_to_remove
buffer_length = len(self._buffer)
self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
self._buffer_sample_lengths = [
self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
]
self._current_index = 0
self._deleted_indices.clear()
class BaseBatchingQueue(ABC):
"""Base class for batching queue."""
@abstractmethod
def is_full_filled(self) -> bool:
raise NotImplementedError("Subclasses must implement `is_full_filled`")
@abstractmethod
def put_item(self, item: dict[str, Any]) -> None:
raise NotImplementedError("Subclasses must implement `put_item`")
@abstractmethod
def get_micro_batch(self, step: int) -> list[dict[str, Any]]:
raise NotImplementedError("Subclasses must implement `get_micro_batch`")
@abstractmethod
def empty(self) -> bool:
raise NotImplementedError("Subclasses must implement `empty`")
class IdentityPacker:
def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
def __call__(self, samples):
return samples
def get_token_num_to_request(self, cur_step, warmup):
return (
(self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
+ self.bsz_warmup_init_mbtoken
if warmup
else self.token_micro_bsz
)
class TextBatchingQueue(BaseBatchingQueue):
"""Batching text queue for text data."""
def __init__(
self,
token_micro_bsz,
buffer_size: int = 500,
bsz_warmup_steps: int = -1,
bsz_warmup_init_mbtoken: int = 200,
) -> None:
super().__init__()
self._step = 0
self.token_micro_bsz = token_micro_bsz
self.bsz_warmup_steps = bsz_warmup_steps
self.buffer_size = buffer_size # minimum samples in buffer
self.buffer = DynamicBatchSizeBuffer()
self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args
assert self.bsz_warmup_init_mbtoken >= 0
self.packer = IdentityPacker(
token_micro_bsz=token_micro_bsz,
bsz_warmup_steps=bsz_warmup_steps,
bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
)
def is_full_filled(self) -> bool:
return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz
def put_item(self, item: dict[str, Any]):
if len(item["input_ids"]) == 1:
print("WARNING: EMPTY STRING.")
return
self.buffer.append(item)
def get_token_num_to_request(self):
if self.packer is not None:
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
return self.packer.get_token_num_to_request(self._step, warmup=warmup)
else:
return self.get_cur_token_micro_bsz()
def get_cur_token_micro_bsz(self):
warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
if warmup:
return (
self.token_micro_bsz - self.bsz_warmup_init_mbtoken
) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
else:
return self.token_micro_bsz
def get_micro_batch(self, step: int) -> list[dict[str, Any]]:
"""Get a micro batch from the buffer according to the current step.
Args:
step: the current step.
Returns:
data: a list of samples.
"""
self._step = step
n_token_per_iter = self.get_token_num_to_request()
cur_token_micro_bsz = self.get_cur_token_micro_bsz()
assert cur_token_micro_bsz % n_token_per_iter == 0, (
"The current token micro batch size must be divisible by the token num requested per iteration."
)
n_iter = int(cur_token_micro_bsz // n_token_per_iter)
data = []
for _ in range(n_iter):
samples = self.buffer.get_samples(n_token_per_iter)
if self.packer:
samples = self.packer(samples) # maybe packed into one sample, but wrapped in list.
data.extend(samples)
self.buffer.flush() # remove the selected samples.
return data
def empty(self) -> bool:
return len(self.buffer) == 0
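
For reference (illustrative constants, not part of the diff), the warmup schedule in get_cur_token_micro_bsz grows the per-micro-batch token budget linearly from bsz_warmup_init_mbtoken to token_micro_bsz over bsz_warmup_steps:

def warmup_budget(step: int, token_micro_bsz: int = 1000, warmup_steps: int = 10, init_mbtoken: int = 200) -> int:
    # mirrors TextBatchingQueue.get_cur_token_micro_bsz with example numbers
    if warmup_steps > 0 and step <= warmup_steps:
        return (token_micro_bsz - init_mbtoken) * step // warmup_steps + init_mbtoken
    return token_micro_bsz

print([warmup_budget(s) for s in (0, 5, 10, 11)])  # [200, 600, 1000, 1000]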

View File

@@ -0,0 +1,173 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.data_loader import DataLoader
from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator,
)
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue
class TensorDataset(Dataset):
"""Wrapper dataset that converts DataEngine samples to tensor format."""
def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
self.data_engine = data_engine
self.processor = processor
self.template = template
self.max_samples = max_samples or len(data_engine)
self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
def __len__(self):
return min(self.max_samples, len(self.data_engine))
def __getitem__(self, idx):
# Get sample from DataEngine
sample = self.data_engine[idx]
# Extract messages from sample
# DataEngine returns samples with format like {"messages": [...], ...}
# For llamafactory/v1-sft-demo, the format should have "messages" field
messages = None
if "messages" in sample:
messages = sample["messages"]
elif "conversations" in sample:
messages = sample["conversations"]
elif "conversation" in sample:
messages = sample["conversation"]
else:
# Try to find message-like fields (skip _dataset_name)
for key, value in sample.items():
if key.startswith("_"):
continue
if isinstance(value, list) and len(value) > 0:
# Check if it looks like a message list
if isinstance(value[0], dict) and "role" in value[0]:
messages = value
break
if messages is None:
raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# Encode messages using template
encoded = self.template.encode_messages(self.tokenizer, messages)
# Convert to tensors
return {
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
"labels": torch.tensor(encoded["labels"], dtype=torch.long),
}
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
"""Create a real dataset using DataEngine."""
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
# Create processor and template
processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
template = QwenTemplate()
# Create tensor dataset
raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# Create torch DataLoader
torch_dataloader = TorchDataLoader(
raw_data_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=lambda x: x,
)
return torch_dataloader, processor, template
class TestDataLoaderNonPackNonDynamic:
"""Test case a) non pack + non dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing and without dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# Create collator (non-packing)
collator = DefaultCollator(processor=processor, template=template)
# Create DataLoader without batching_queue (non-dynamic)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=1,
batching_queue=None,
)
# Iterate and check results
batches = list(iter(data_loader))
assert len(batches) > 0
# Check first batch
one_batch = batches[0]
micro_batches = one_batch[0]
assert "input_ids" in micro_batches
assert "attention_mask" in micro_batches
assert "labels" in micro_batches
assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
class TestDataLoaderNonPackDynamic:
"""Test case b) non pack + dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing but with dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
collator = DefaultCollator(processor=processor, template=template)
# Create batching queue for dynamic batching
batching_queue = TextBatchingQueue(
token_micro_bsz=120,
buffer_size=8,
)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=4,
batching_queue=batching_queue,
)
# Iterate and check
batches = list(iter(data_loader))
micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
assert len(batches) > 0
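
The module docstring lists four scenarios, but only a) and b) appear in this hunk. A hypothetical sketch (not from the commit) of how scenario d) pack + dynamic could look under the same API, swapping DefaultCollator for DataCollatorWithPacking:

from llamafactory.v1.core.trainer_utils.data_collator import DataCollatorWithPacking

class TestDataLoaderPackDynamicSketch:
    """Hypothetical sketch for case d) pack + dynamic."""
    def test_basic_functionality(self):
        torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
        collator = DataCollatorWithPacking(processor=processor, template=template)  # packs instead of padding
        data_loader = DataLoader(
            dataloader=torch_dataloader,
            collate_fn=collator,
            num_micro_batch=4,
            batching_queue=TextBatchingQueue(token_micro_bsz=120, buffer_size=8),
        )
        first_batch = next(iter(data_loader))
        assert len(first_batch) == 4
        for micro_batch in first_batch:
            assert "cu_seqlens" in micro_batch         # packed micro batches carry cumulative offsets
            assert micro_batch["input_ids"].ndim == 1  # sequences are concatenated, not padded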

View File

@@ -0,0 +1,112 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue
def create_sample(length: int):
"""Helper to create a mock sample with a specific token length."""
return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)}
class TestDynamicBatchSizeBuffer:
def test_append_and_token_count(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(10))
buffer.append(create_sample(20))
assert len(buffer) == 2
assert buffer.total_token_count == 30
def test_get_samples_within_budget(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(10))
buffer.append(create_sample(10))
buffer.append(create_sample(50)) # This one is large
# Request 25 tokens. Should get the first two (20 tokens total)
samples = buffer.get_samples(max_tokens_per_iteration=25)
assert len(samples) == 2
def test_force_return_first_sample(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(100))
# Even though budget is 50, force=True (default) should return the 100-token sample
samples = buffer.get_samples(max_tokens_per_iteration=50, force=True)
assert len(samples) == 1
assert len(samples[0]["input_ids"]) == 100
def test_flush_removes_used_samples(self):
buffer = DynamicBatchSizeBuffer()
buffer.append(create_sample(10))
buffer.append(create_sample(20))
# Take the first sample
buffer.get_samples(max_tokens_per_iteration=15)
buffer.flush()
assert len(buffer) == 1
assert buffer.total_token_count == 20
# The remaining sample should now be at the start
remaining = buffer.get_samples(max_tokens_per_iteration=50)
assert len(remaining[0]["input_ids"]) == 20
class TestTextBatchingQueue:
def test_is_full_filled(self):
queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2)
queue.put_item(create_sample(10))
assert not queue.is_full_filled() # Only 1 sample, buffer_size=2
queue.put_item(create_sample(10))
assert not queue.is_full_filled() # 2 samples, but only 20 tokens (min 100)
queue.put_item(create_sample(90))
assert queue.is_full_filled() # Meets both conditions
def test_warmup_logic(self):
# token_micro_bsz=1000, starts at 200, reaches 1000 at step 10
queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200)
# Step 0: should be init value
assert queue.get_cur_token_micro_bsz() == 200
# Step 5: halfway through warmup (200 + (800 * 5/10)) = 600
queue._step = 5
assert queue.get_cur_token_micro_bsz() == 600
# Step 11: past warmup
queue._step = 11
assert queue.get_cur_token_micro_bsz() == 1000
def test_get_micro_batch_integration(self):
queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1)
queue.put_item(create_sample(20))
queue.put_item(create_sample(20))
queue.put_item(create_sample(20))
# At step 0 (warmup not triggered as bsz_warmup_steps is -1 default),
# it should take samples up to 50 tokens.
batch = queue.get_micro_batch(step=0)
assert len(batch) == 2
assert queue.empty() is False
batch_2 = queue.get_micro_batch(step=1)
assert len(batch_2) == 1
assert queue.empty() is True