mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-23 23:30:36 +08:00)

[v1&WIP] dataloader init (#9645)
@@ -5,6 +5,7 @@ datasets>=2.16.0,<=4.0.0
accelerate>=1.3.0,<=1.11.0
peft>=0.14.0,<=0.17.1
trl>=0.8.6,<=0.9.6
torchdata
# gui
gradio>=4.38.0,<=5.45.0
matplotlib>=3.7.0
src/llamafactory/v1/core/data_loader.py (new file, 277 lines)
@@ -0,0 +1,277 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import sys
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from typing import Optional

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

from ..utils.batching_queue import BaseBatchingQueue
from ..utils.logging import get_logger
from ..utils.types import Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator


logger = get_logger(__name__)


# base dataloader
class DistributedDataloader(StatefulDataLoader):
    """Base Distributed DataLoader."""

    dataset: "TorchDataset"
    sampler: "StatefulDistributedSampler"

    def set_epoch(self, epoch: int) -> None:
        if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)


@dataclass
class BaseDataLoader:
    """Default DataLoader."""

    processor: Processor

    def __init__(self, dataset: TorchDataset) -> None:
        self.dataset = dataset
        # guidelines: fetch until we get a fixed batch size.
        # save state_dict for buffer.
        # resume with state

        # 1. Init stateful dataloader (tokenize)
        # 2. Add to buffer (2 * max seq len per device)
        # 3. Yield batch indexes (micro batch * grad acc)
        # a ) non pack + non dynamic
        # b ) non pack + dynamic
        # c ) pack + non dynamic
        # d ) pack + dynamic

    def init_dataloader(self) -> None:
        ### init dataloader
        pass

    def __iter__(self) -> Iterator:
        pass

    def __next__(self) -> any:
        pass


@dataclass
class DataLoader:
    """Default DataLoader."""

    processor: "Processor"
    dataloader: "DistributedDataloader"
    batching_queue: "BaseBatchingQueue"
    collate_fn: "DataCollator"
    num_micro_batch: int = 1
    length: int = 0
    drop_last: bool = True

    def __init__(
        self,
        dataloader: any,
        collate_fn: "DataCollator",
        num_micro_batch: int = 1,
        length: int = 0,
        drop_last: bool = True,
        batching_queue: Optional["BaseBatchingQueue"] = None,
    ) -> None:
        self.batching_queue = batching_queue
        self.num_micro_batch = num_micro_batch
        self.step = 0
        self._collate_fn = collate_fn
        self._dataloader = dataloader
        self._drop_last = drop_last
        self._data_iter: Iterator
        self._resume = False
        self._batch_data_iter: Generator

        if length > 0:
            self._length = length
        elif length == -1:
            self._length = sys.maxsize
        else:
            self._length = len(self._dataloader)

    def __len__(self):
        return self._length

    def __iter__(self) -> Iterator:
        if not self._resume:
            self.step = 0
            self._data_iter = iter(self._dataloader)
            self._batch_data_iter = self.batch_data_generator()
        self._resume = False
        return self

    def __next__(self):
        return next(self._batch_data_iter)  # FIXME maybe we can move origin_batch_data_generator to here

    def origin_batch_data_generator(self):
        """Standard pass-through generator used when no batching queue is configured."""
        while True:
            if self._length > 0 and self.step >= self._length:
                return

            try:
                batch = []
                data = next(self._data_iter)
                # split data into micro batches
                for i in range(0, len(data), self.num_micro_batch):
                    micro_batch = data[i : i + self.num_micro_batch]
                    if self._collate_fn:
                        micro_batch = self._collate_fn(micro_batch)
                    batch.append(micro_batch)
                yield batch
                self.step += 1
            except StopIteration:
                if self.step < self._length:
                    # Restart iterator to fill the requested length
                    self._data_iter = iter(self._dataloader)
                    try:
                        batch = []
                        data = next(self._data_iter)
                        for i in range(0, len(data), self.num_micro_batch):
                            micro_batch = data[i : i + self.num_micro_batch]
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)
                            batch.append(micro_batch)
                        yield batch
                        self.step += 1
                    except StopIteration:
                        return
                else:
                    return
            except Exception as e:
                logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
                raise

    def batch_data_generator(self):
        if self.batching_queue is None:
            yield from self.origin_batch_data_generator()
            return

        batch = []

        while True:
            if self._length and self.step >= self._length:
                return

            if self.batching_queue.is_full_filled():
                micro_batch = self.batching_queue.get_micro_batch(self.step)
                if self._collate_fn:
                    micro_batch = self._collate_fn(micro_batch)
                batch.append(micro_batch)
                if len(batch) == self.num_micro_batch:
                    yield batch
                    self.step += 1
                    batch = []

            try:
                processing_item = next(self._data_iter)
            except Exception as e:
                if isinstance(e, StopIteration):
                    if self.step < self._length:
                        # restart the iterator until the requested length is reached
                        self._data_iter = iter(self._dataloader)
                        processing_item = next(self._data_iter)
                    elif not self._drop_last and not self.batching_queue.empty():
                        while not self.batching_queue.empty():
                            micro_batch = self.batching_queue.get_micro_batch(self.step)
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)
                            batch.append(micro_batch)
                            if len(batch) == self.num_micro_batch:
                                yield batch
                                self.step += 1
                                batch = []

                        while len(batch) < self.num_micro_batch:
                            padding_batch = copy.deepcopy(micro_batch)
                            padding_batch["is_padded"] = True
                            batch.append(padding_batch)
                        yield batch
                        self.step += 1
                        return
                    else:
                        return
                else:
                    logger.error(f"DataLoader iter data exception: {e}")
                    raise

            # put processing_item to buffer
            if isinstance(processing_item, dict):
                processing_item = [processing_item]

            for item in processing_item:
                self.batching_queue.put_item(item)

    def state_dict(self):
        # save state
        state = self.__dict__.copy()
        # remove internal fields
        for k in list(state.keys()):
            if k.startswith("_"):
                del state[k]

        # save dataloader state
        if hasattr(self._dataloader, "state_dict"):
            state["dataloader_state"] = self._dataloader.state_dict()
        elif hasattr(self._dataloader, "__getstate__"):
            state["dataloader_state"] = self._dataloader.__getstate__()

        batching_strategy = getattr(self, "batching_strategy", None)
        if batching_strategy and hasattr(batching_strategy, "state_dict"):
            state["batching_strategy_state"] = batching_strategy.state_dict()
            if "batching_strategy" in state:
                del state["batching_strategy"]

        return copy.deepcopy(state)

    def load_state_dict(self, state: dict[str, any]):
        if state["num_micro_batch"] != self.num_micro_batch:
            logger.warning(
                f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
            )
            del state["num_micro_batch"]
        self.__dict__.update(state)
        self._resume = True

        if hasattr(self._dataloader, "load_state_dict"):
            self._dataloader.load_state_dict(state["dataloader_state"])
        elif hasattr(self._dataloader, "__getstate__"):
            self._dataloader.__setstate__(state["dataloader_state"])

        if "batching_strategy_state" in state:
            batching_strategy = getattr(self, "batching_strategy", None)
            if batching_strategy:
                batching_strategy.load_state_dict(state["batching_strategy_state"])
            del state["batching_strategy_state"]

        self._data_iter = iter(self._dataloader)
        self._batch_data_iter = self.batch_data_generator()

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self._dataloader, "set_epoch"):
            self._dataloader.set_epoch(epoch)
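The class above is easiest to understand from the caller's side. Below is a minimal usage sketch (illustrative, not part of the patch): it wraps a plain torch DataLoader whose collate_fn returns the raw list of samples, leaves batching_queue unset so origin_batch_data_generator is used, and iterates over lists of micro batches. ToyDataset and the pass-through collate function are made-up names for this example.

# Illustrative usage sketch (not part of this commit).
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset

from llamafactory.v1.core.data_loader import DataLoader  # the class defined above


class ToyDataset(Dataset):
    """Tiny synthetic dataset: sample idx has idx + 1 tokens."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {"input_ids": torch.arange(idx + 1), "attention_mask": torch.ones(idx + 1)}


torch_loader = TorchDataLoader(ToyDataset(), batch_size=4, collate_fn=lambda x: x)

# no batching_queue and no collate_fn: the pass-through generator just slices each
# fetched list into micro batches of num_micro_batch samples
loader = DataLoader(dataloader=torch_loader, collate_fn=None, num_micro_batch=1)

for batch in loader:  # each batch is a list of micro batches
    print(len(batch), [mb[0]["input_ids"].shape[0] for mb in batch])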
src/llamafactory/v1/core/trainer_utils/data_collator.py
@@ -12,36 +12,108 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

from ...utils.types import Processor, Tensor, TorchDataset
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate

from ....extras.constants import IGNORE_INDEX
from ...plugins.data_plugins.template import Template
from ...utils.types import Processor, Tensor


def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":  # FIXME move to utils
    """Convert sequence lengths to cumulative sequence lengths."""
    return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)


class DataCollator:
    """Default Data collator."""

    def __init__(self, processor: Processor) -> None:
        self.processor = processor
    processor: "Processor"  # processor name -> map to encode_messages function

    def __post_init__(self):
        # callback for text tokenizer
        self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
        """Collate features into a batch."""
        for feature in features:
            pass
        batch = defaultdict(list)

        # batching features
        for feature in features:
            for key in feature.keys():
                batch[key].append(feature[key])

        for key in batch.keys():
            # process padding features
            if key in ["input_ids", "attention_mask", "position_ids"]:
                padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
            elif key in ["labels"]:
                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
            else:
                batch[key] = default_collate(batch[key])

        return batch
        # sft: messages
        # dpo: chosen_messages, rejected_messages


class DataLoader:
    """Default DataLoader."""
@dataclass
class DefaultCollator(DataCollator):
    """Example for now."""

    def __init__(self, dataset: TorchDataset) -> None:
        self.dataset = dataset
    # 1. Init stateful dataloader (tokenize)
    # 2. Add to buffer (2 * max seq len per device)
    # 3. Yield batch indexes (micro batch * grad acc)
    # a ) non pack + non dynamic
    # b ) non pack + dynamic
    # c ) pack + non dynamic
    # d ) pack + dynamic
    processor: "Processor"  # processor name -> map to encode_messages function
    template: "Template"

    def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
        features = []

        # Check if data is already tokenized (contains input_ids)
        if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
            for feature in messages:
                if not isinstance(feature, dict):
                    raise ValueError(f"Expected dict but got {type(feature)}")
                tensor_feature = {
                    k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
                    for k, v in feature.items()
                }
                features.append(tensor_feature)
        else:
            # raw messages need to be encoded
            for message in messages:
                encoded_message = self.template.encode_messages(self.tokenizer, message)
                encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
                features.append(encoded_message)

        return super().__call__(features)


@dataclass
class PairwiseCollator(DataCollator):
    pass


@dataclass
class DataCollatorWithPacking(DefaultCollator):
    """Data collator with packing."""

    processor: "Processor"
    template: "Template"

    def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
        seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
        batch = {"cu_seqlens": len2culen(seqlens)}
        for input_name in features[0].keys():
            if input_name in ("input_ids", "attention_mask", "labels"):
                batch[input_name] = torch.cat([feature[input_name] for feature in features])
            else:
                batch[input_name] = default_collate([feature[input_name] for feature in features])

        return batch
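The packing collator above returns one flat token stream per micro batch plus cu_seqlens boundaries. A standalone sketch of the len2culen computation follows (illustrative; it re-implements the helper so the snippet runs on its own):

# Standalone sketch of the cu_seqlens bookkeeping used by DataCollatorWithPacking above.
import torch
import torch.nn.functional as F


def len2culen(seqlens: torch.Tensor) -> torch.Tensor:
    # prepend a zero to the running sum so entry i / i+1 bound sequence i
    return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)


# three packed sequences of lengths 3, 2 and 4 -> boundaries at 0, 3, 5, 9
seqlens = torch.tensor([3, 2, 4])
print(len2culen(seqlens))  # tensor([0, 3, 5, 9], dtype=torch.int32)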
src/llamafactory/v1/plugins/data_plugins/template.py
@@ -14,6 +14,7 @@


from dataclasses import dataclass
from typing import Union


@dataclass
@@ -22,5 +23,112 @@ class Template:
    assistant_template: str
    system_template: str

    def render_message(self, message: "dict[str, str]") -> str:
    def render_message(self, message: dict[str, str]) -> str:
        return self.user_template.format(**message)


@dataclass
class QwenTemplate:
    message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME if role: tool
    thinking_template: str = "<think>\n{content}\n</think>\n\n"

    def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
        if isinstance(content_data, str):
            return content_data.strip()

        if isinstance(content_data, list):
            parts = []
            for item in content_data:
                if item.get("type") == "text":
                    parts.append(item.get("value", ""))
                elif item.get("type") == "image_url":
                    pass
            return "\n".join(parts).strip()

        return ""

    def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
        role = message["role"]
        content = self._extract_content(message.get("content", ""))

        if role == "assistant":
            reasoning_content = message.get("reasoning_content", "")
            if reasoning_content:
                reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
            return self.message_template.format(role="assistant", content=reasoning_content + content)
        else:
            return self.message_template.format(role=role, content=content)

    def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
        """Encode a list of messages."""
        input_ids, attention_mask, labels = [], [], []
        for message in messages:
            content_str = self.render_message(message)
            content_ids = tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)

            if hasattr(message, "loss_weight"):
                loss_weight = message["loss_weight"]
            else:
                loss_weight = 1 if message["role"] == "assistant" else 0
            if loss_weight == 1:
                labels += content_ids
            else:
                labels += [-100] * len(content_ids)
        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        model_inputs.update({"position_ids": list(range(len(input_ids)))})
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs


if __name__ == "__main__":

    def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
        out = []
        for m in messages:
            role = m["role"]
            content = template._extract_content(m.get("content", ""))
            if role == "assistant":
                reasoning = (m.get("reasoning_content") or "").strip()
                if reasoning:
                    content = template.thinking_template.format(content=reasoning) + content
            out.append({"role": role, "content": content})
        return out

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        trust_remote_code=True,
    )

    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}],
        },
        {
            "role": "assistant",
            "reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
            "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
        },
    ]

    template = QwenTemplate()
    rendered_custom = "".join([template.render_message(m) for m in test_messages])

    qwen3_messages = to_qwen3_messages(template, test_messages)
    rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)

    print("==== custom ====")
    print(rendered_custom)
    print("==== hf ====")
    print(rendered_hf)

    assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"

    ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
    ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
    assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
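A small sketch of the label masking performed by encode_messages above (illustrative, not part of the patch; FakeTokenizer is a made-up byte-level stand-in so the snippet runs without downloading a real tokenizer, and QwenTemplate is assumed importable at the path used by the tests):

# Illustrative sketch of QwenTemplate label masking (FakeTokenizer is hypothetical).
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate


class FakeTokenizer:
    def encode(self, text, add_special_tokens=False):
        return list(text.encode("utf-8"))  # one "token" per byte, enough for a demo


template = QwenTemplate()
tokenizer = FakeTokenizer()
messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]

out = template.encode_messages(tokenizer, messages)
assistant_ids = tokenizer.encode(template.render_message(messages[1]), add_special_tokens=False)

# only the assistant turn contributes supervised tokens; user tokens are masked with -100
assert sum(label != -100 for label in out["labels"]) == len(assistant_ids)
assert len(out["input_ids"]) == len(out["labels"]) == len(out["attention_mask"]) == len(out["position_ids"])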
src/llamafactory/v1/utils/batching_queue.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod


class DynamicBatchSizeBuffer:
    """A buffer to store samples for dynamic batch size."""

    def __init__(self):
        self._buffer: list[dict[str, any]] = []
        self._buffer_sample_lengths: list[int] = []
        self._deleted_indices: set[int] = set()
        self._current_index: int = 0
        self._total_token_count: int = 0

    def append(self, item: dict[str, any]) -> None:
        """Append a sample to the buffer.

        Args:
            item: A sample to append to the buffer.
                The sample should be a dict with the following keys:
                - input_ids: torch.Tensor of shape (seq_len, )
                - attention_mask: torch.Tensor of shape (seq_len, )
        """
        self._buffer.append(item)
        sample_length = int(item["attention_mask"].sum().item())
        self._buffer_sample_lengths.append(sample_length)
        self._total_token_count += sample_length

    def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]:
        """Get samples from the buffer that fit within the token budget.

        Args:
            max_tokens_per_iteration: Maximum number of tokens to retrieve.
            force: If True, the first available sample will be returned even
                if it exceeds the token budget.

        Returns:
            A list of samples that fit within the token budget.

        Raises:
            AssertionError: If no samples are found (should not happen in normal operation).
        """
        cum_seq_len = 0
        samples = []

        while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
            if self._current_index in self._deleted_indices:
                self._current_index += 1
                continue

            seq_len = self._buffer_sample_lengths[self._current_index]
            remaining_tokens = max_tokens_per_iteration - cum_seq_len

            # Check if we can add this sample
            can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)

            if can_add:
                cum_seq_len += seq_len
                samples.append(self._buffer[self._current_index])
                self._deleted_indices.add(self._current_index)

            self._current_index += 1

        assert len(samples) > 0, "No samples found in buffer"
        return samples

    def __len__(self) -> int:
        """Return the number of samples in the buffer."""
        return len(self._buffer)

    @property
    def total_token_count(self) -> int:
        """Return the total number of tokens in the buffer."""
        return self._total_token_count

    def flush(self) -> None:
        tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
        self._total_token_count -= tokens_to_remove

        buffer_length = len(self._buffer)
        self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
        self._buffer_sample_lengths = [
            self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
        ]

        self._current_index = 0
        self._deleted_indices.clear()


class BaseBatchingQueue(ABC):
    """Base class for batching queue."""

    @abstractmethod
    def is_full_filled(self) -> bool:
        raise NotImplementedError("Subclasses must implement `is_full_filled`")

    @abstractmethod
    def put_item(self, item: dict[str, any]) -> None:
        raise NotImplementedError("Subclasses must implement `put_item`")

    @abstractmethod
    def get_micro_batch(self, step: int) -> list[dict[str, any]]:
        raise NotImplementedError("Subclasses must implement `get_micro_batch`")

    @abstractmethod
    def empty(self) -> bool:
        raise NotImplementedError("Subclasses must implement `empty`")


class IdentityPacker:
    def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken

    def __call__(self, samples):
        return samples

    def get_token_num_to_request(self, cur_step, warmup):
        return (
            (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
            + self.bsz_warmup_init_mbtoken
            if warmup
            else self.token_micro_bsz
        )


class TextBatchingQueue(BaseBatchingQueue):
    """Batching text queue for text data."""

    def __init__(
        self,
        token_micro_bsz,
        buffer_size: int = 500,
        bsz_warmup_steps: int = -1,
        bsz_warmup_init_mbtoken: int = 200,
    ) -> None:
        super().__init__()
        self._step = 0
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.buffer_size = buffer_size  # minimum samples in buffer
        self.buffer = DynamicBatchSizeBuffer()
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken  # training warmup args
        assert self.bsz_warmup_init_mbtoken >= 0

        self.packer = IdentityPacker(
            token_micro_bsz=token_micro_bsz,
            bsz_warmup_steps=bsz_warmup_steps,
            bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
        )

    def is_full_filled(self) -> bool:
        return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz

    def put_item(self, item: dict[str, any]):
        if len(item["input_ids"]) == 1:
            print("WARNING: EMPTY STRING.")
            return
        self.buffer.append(item)

    def get_token_num_to_request(self):
        if self.packer is not None:
            warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
            return self.packer.get_token_num_to_request(self._step, warmup=warmup)
        else:
            return self.get_cur_token_micro_bsz()

    def get_cur_token_micro_bsz(self):
        warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
        if warmup:
            return (
                self.token_micro_bsz - self.bsz_warmup_init_mbtoken
            ) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
        else:
            return self.token_micro_bsz

    def get_micro_batch(self, step) -> any:
        """Get a micro batch from the buffer according to the current step.

        Args:
            step: the current step.

        Returns:
            data: a list of samples.
        """
        self._step = step
        n_token_per_iter = self.get_token_num_to_request()
        cur_token_micro_bsz = self.get_cur_token_micro_bsz()
        assert cur_token_micro_bsz % n_token_per_iter == 0, (
            "The token num to get for each request should be divisible by token micro bsz."
        )
        n_iter = int(cur_token_micro_bsz // n_token_per_iter)
        data = []
        for _ in range(n_iter):
            samples = self.buffer.get_samples(n_token_per_iter)
            if self.packer:
                samples = self.packer(samples)  # maybe packed into one sample, but wrapped in list.
            data.extend(samples)
        self.buffer.flush()  # remove the selected samples.
        return data

    def empty(self) -> bool:
        return len(self.buffer) == 0
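The warmup schedule in get_cur_token_micro_bsz ramps the per-micro-batch token budget linearly from bsz_warmup_init_mbtoken up to token_micro_bsz over bsz_warmup_steps steps. A standalone sketch of the formula, mirroring test_warmup_logic in the tests further below:

# Standalone sketch of the batch-size warmup formula (not part of the patch).
def warmup_budget(step: int, token_micro_bsz: int = 1000, warmup_steps: int = 10, init_mbtoken: int = 200) -> int:
    # same condition as TextBatchingQueue: warmup applies while step <= warmup_steps and warmup is enabled
    warmup = step <= warmup_steps and warmup_steps > 0
    if warmup:
        return (token_micro_bsz - init_mbtoken) * step // warmup_steps + init_mbtoken
    return token_micro_bsz


# 200 tokens at step 0, 600 halfway through warmup, the full 1000 once warmup is over
assert [warmup_budget(s) for s in (0, 5, 11)] == [200, 600, 1000]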
tests_v1/core/test_data_loader.py (new file, 173 lines)
@@ -0,0 +1,173 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration tests for DataLoader with different combinations of packing and dynamic batching.

Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""

import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.data_loader import DataLoader
from llamafactory.v1.core.trainer_utils.data_collator import (
    DefaultCollator,
)
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue


class TensorDataset(Dataset):
    """Wrapper dataset that converts DataEngine samples to tensor format."""

    def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
        self.data_engine = data_engine
        self.processor = processor
        self.template = template
        self.max_samples = max_samples or len(data_engine)
        self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor

    def __len__(self):
        return min(self.max_samples, len(self.data_engine))

    def __getitem__(self, idx):
        # Get sample from DataEngine
        sample = self.data_engine[idx]

        # Extract messages from sample
        # DataEngine returns samples with format like {"messages": [...], ...}
        # For llamafactory/v1-sft-demo, the format should have "messages" field
        messages = None
        if "messages" in sample:
            messages = sample["messages"]
        elif "conversations" in sample:
            messages = sample["conversations"]
        elif "conversation" in sample:
            messages = sample["conversation"]
        else:
            # Try to find message-like fields (skip _dataset_name)
            for key, value in sample.items():
                if key.startswith("_"):
                    continue
                if isinstance(value, list) and len(value) > 0:
                    # Check if it looks like a message list
                    if isinstance(value[0], dict) and "role" in value[0]:
                        messages = value
                        break

        if messages is None:
            raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")

        # Encode messages using template
        encoded = self.template.encode_messages(self.tokenizer, messages)

        # Convert to tensors
        return {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(encoded["labels"], dtype=torch.long),
        }


def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
    """Create a real dataset using DataEngine."""
    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
    data_engine = DataEngine(data_args)

    # Create processor and template
    processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
    template = QwenTemplate()

    # Create tensor dataset
    raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)

    # Create torch DataLoader
    torch_dataloader = TorchDataLoader(
        raw_data_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: x,
    )

    return torch_dataloader, processor, template


class TestDataLoaderNonPackNonDynamic:
    """Test case a) non pack + non dynamic."""

    def test_basic_functionality(self):
        """Test DataLoader without packing and without dynamic batching."""
        # Create real dataset
        torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)

        # Create collator (non-packing)
        collator = DefaultCollator(processor=processor, template=template)

        # Create DataLoader without batching_queue (non-dynamic)
        data_loader = DataLoader(
            dataloader=torch_dataloader,
            collate_fn=collator,
            num_micro_batch=1,
            batching_queue=None,
        )

        # Iterate and check results
        batches = list(iter(data_loader))
        assert len(batches) > 0

        # Check first batch
        one_batch = batches[0]
        micro_batches = one_batch[0]
        assert "input_ids" in micro_batches
        assert "attention_mask" in micro_batches
        assert "labels" in micro_batches
        assert micro_batches["input_ids"].shape[0] == 1  # batch_size=1
        assert micro_batches["input_ids"].ndim == 2  # [batch_size, seq_len]


class TestDataLoaderNonPackDynamic:
    """Test case b) non pack + dynamic."""

    def test_basic_functionality(self):
        """Test DataLoader without packing but with dynamic batching."""
        # Create real dataset
        torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
        collator = DefaultCollator(processor=processor, template=template)

        # Create batching queue for dynamic batching
        batching_queue = TextBatchingQueue(
            token_micro_bsz=120,
            buffer_size=8,
        )

        data_loader = DataLoader(
            dataloader=torch_dataloader,
            collate_fn=collator,
            num_micro_batch=4,
            batching_queue=batching_queue,
        )

        # Iterate and check
        batches = list(iter(data_loader))
        micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
        assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
        assert len(batches) > 0
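The committed tests cover scenarios a) and b). One possible sketch for the packing scenario c) is shown below (illustrative only, not part of the test file; it reuses create_real_dataset and DataLoader from the file above and assumes DataCollatorWithPacking behaves as defined in the collator diff earlier in this commit):

# Hypothetical packing test sketch, not part of the committed tests.
from llamafactory.v1.core.trainer_utils.data_collator import DataCollatorWithPacking


class TestDataLoaderPackingSketch:
    """Illustrative sketch for case c) pack + non dynamic."""

    def test_packed_micro_batch_shapes(self):
        torch_dataloader, processor, template = create_real_dataset(max_samples=16, batch_size=4)
        collator = DataCollatorWithPacking(processor=processor, template=template)

        data_loader = DataLoader(
            dataloader=torch_dataloader,
            collate_fn=collator,
            num_micro_batch=1,
            batching_queue=None,
        )

        micro_batch = next(iter(data_loader))[0]
        # packing concatenates samples into one flat stream and records boundaries in cu_seqlens
        assert "cu_seqlens" in micro_batch
        assert micro_batch["input_ids"].ndim == 1
        assert int(micro_batch["cu_seqlens"][-1]) == micro_batch["input_ids"].shape[0]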
tests_v1/utils/test_batching_queue.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue


def create_sample(length: int):
    """Helper to create a mock sample with a specific token length."""
    return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)}


class TestDynamicBatchSizeBuffer:
    def test_append_and_token_count(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(20))

        assert len(buffer) == 2
        assert buffer.total_token_count == 30

    def test_get_samples_within_budget(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(10))
        buffer.append(create_sample(50))  # This one is large

        # Request 25 tokens. Should get the first two (20 tokens total)
        samples = buffer.get_samples(max_tokens_per_iteration=25)
        assert len(samples) == 2

    def test_force_return_first_sample(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(100))

        # Even though budget is 50, force=True (default) should return the 100-token sample
        samples = buffer.get_samples(max_tokens_per_iteration=50, force=True)
        assert len(samples) == 1
        assert len(samples[0]["input_ids"]) == 100

    def test_flush_removes_used_samples(self):
        buffer = DynamicBatchSizeBuffer()
        buffer.append(create_sample(10))
        buffer.append(create_sample(20))

        # Take the first sample
        buffer.get_samples(max_tokens_per_iteration=15)
        buffer.flush()

        assert len(buffer) == 1
        assert buffer.total_token_count == 20
        # The remaining sample should now be at the start
        remaining = buffer.get_samples(max_tokens_per_iteration=50)
        assert len(remaining[0]["input_ids"]) == 20


class TestTextBatchingQueue:
    def test_is_full_filled(self):
        queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2)

        queue.put_item(create_sample(10))
        assert not queue.is_full_filled()  # Only 1 sample, buffer_size=2

        queue.put_item(create_sample(10))
        assert not queue.is_full_filled()  # 2 samples, but only 20 tokens (min 100)

        queue.put_item(create_sample(90))
        assert queue.is_full_filled()  # Meets both conditions

    def test_warmup_logic(self):
        # token_micro_bsz=1000, starts at 200, reaches 1000 at step 10
        queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200)

        # Step 0: should be init value
        assert queue.get_cur_token_micro_bsz() == 200

        # Step 5: halfway through warmup (200 + (800 * 5/10)) = 600
        queue._step = 5
        assert queue.get_cur_token_micro_bsz() == 600

        # Step 11: past warmup
        queue._step = 11
        assert queue.get_cur_token_micro_bsz() == 1000

    def test_get_micro_batch_integration(self):
        queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1)
        queue.put_item(create_sample(20))
        queue.put_item(create_sample(20))
        queue.put_item(create_sample(20))

        # At step 0 (warmup not triggered as bsz_warmup_steps is -1 default),
        # it should take samples up to 50 tokens.
        batch = queue.get_micro_batch(step=0)

        assert len(batch) == 2
        assert queue.empty() is False

        batch_2 = queue.get_micro_batch(step=1)
        assert len(batch_2) == 1
        assert queue.empty() is True